# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from collections import defaultdict
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)
from executorch.exir.common import setting_python_recursive_limit
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.lowered_backend_module import create_submodule_from_nodes
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.fx.node import Node
from torch.fx.passes.utils.source_matcher_utils import SourcePartition

T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


# NB: maxsize is set to None to handle validation of models like MobileBert.
@lru_cache(maxsize=None)
def is_same_node(
    node_left: Union[torch.fx.Node, Iterable[torch.fx.Node]],
    node_right: Union[torch.fx.Node, Iterable[torch.fx.Node]],
) -> bool:
    # Two nodes are the same if they have the same target and op, and the same
    # holds recursively for their input nodes.
    if isinstance(node_left, torch.fx.Node) and isinstance(node_right, torch.fx.Node):
        if not (
            (node_left.target == node_right.target)
            and (node_left.op == node_right.op)
            and (len(node_left.all_input_nodes) == len(node_right.all_input_nodes))
            and all(
                is_same_node(arg_left, arg_right)
                for arg_left, arg_right in zip(
                    node_left.all_input_nodes, node_right.all_input_nodes
                )
            )
        ):
            return False
    else:
        if len(list(node_left)) != len(list(node_right)):
            return False
        for n_left, n_right in zip(node_left, node_right):
            if not is_same_node(n_left, n_right):
                return False
    return True
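

# A minimal usage sketch for is_same_node, assuming two independently traced
# copies of the same toy module (the module and trace calls below are
# illustrative, not part of this file):
#
#     class Add(torch.nn.Module):
#         def forward(self, x, y):
#             return x + y
#
#     gm_a = torch.fx.symbolic_trace(Add())
#     gm_b = torch.fx.symbolic_trace(Add())
#     out_a = next(n for n in gm_a.graph.nodes if n.op == "output")
#     out_b = next(n for n in gm_b.graph.nodes if n.op == "output")
#     # Comparing the output nodes recursively compares their input chains.
#     assert is_same_node(out_a, out_b)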


def is_identical_graph(
    graph_left: torch.fx.GraphModule, graph_right: torch.fx.GraphModule
) -> bool:
    # Two graphs are the same if they have the same nodes and ops. This check is
    # strict about ordering: even if two graphs are topologically equivalent,
    # they are not considered identical unless their nodes appear in the same
    # order.
    if len(list(graph_left.graph.nodes)) != len(list(graph_right.graph.nodes)):
        return False
    with setting_python_recursive_limit(30000):
        for node_left, node_right in zip(
            graph_left.graph.nodes, graph_right.graph.nodes
        ):
            if not is_same_node(node_left, node_right):
                return False
    return True
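

# Because node order matters, is_identical_graph distinguishes graphs that are
# merely topologically equivalent. A hedged sketch, reusing the hypothetical
# Add module from the example above:
#
#     gm_a = torch.fx.symbolic_trace(Add())
#     gm_b = torch.fx.symbolic_trace(Add())
#     assert is_identical_graph(gm_a, gm_b)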


def remove_first_quant_and_last_dequant(
    graph_module: torch.fx.GraphModule,
) -> None:
    for node in graph_module.graph.nodes:
        if node.target == T_QuantPerTensor:
            if node.args[0].op == "placeholder":
                node_users = list(node.users.keys())
                for dequant_node in node_users:
                    # Point the dequant arg to the placeholder.
                    dequant_node.args = (node.args[0],) + dequant_node.args[1:]
        elif node.target == T_DQuantPerTensor:
            node_users = list(node.users.keys())
            if node_users[0].op == "output":
                # Point the output arg to the quant node.
                output_node = node_users[0]
                output_node.args = ([node.args[0]],)
    # Remove the quant/dequant nodes now that they have no users.
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
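

# Schematic of the rewrite above (not runnable code): the leading quantize on a
# placeholder input and the trailing dequantize before the output are dropped,
# leaving the graph to consume/produce quantized tensors directly.
#
#     before: placeholder -> q -> dq -> ... -> q -> dq -> output
#     after:  placeholder ------> dq -> ... -> q ------> output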


# TODO - use edge ops
def replace_quantized_partition_with_op(
    graph_module: torch.fx.GraphModule,
    partition: SourcePartition,
    replacement_op: torch._ops.OpOverloadPacket,
) -> Tuple[torch.fx.Node, List[torch.fx.Node], List[torch.fx.Node]]:
    """
    Replaces the partition with the op specified by replacement_op. The nodes in
    the partition are expected to be sourced from a quantized module, as this
    function searches for the surrounding quantization pattern and consumes it
    along with the nodes in the partition, replacing them all with replacement_op.

    Args:
        graph_module: The graph module from which this partition was sourced.
        partition: Partition to be replaced.
        replacement_op: The op to replace the partition with.
    Returns:
        Tuple: The first element is the new replacement node. The second and
        third elements are the lists of dq and q nodes that were consumed along
        with this partition when it was replaced by replacement_op.
    """

    dequant_nodes = []
    quant_nodes = []
    input_nodes = []
    output_nodes = []

    partition_nodes = [node for node in partition.nodes if node not in partition.params]

    # We recreate our input and output node lists instead of using
    # partition.input_nodes and partition.output_nodes, because the ordering of
    # those lists is not deterministic, whereas the quant fusion pass expects
    # deterministic ordering.
    for node in partition.nodes:
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and (arg not in partition.nodes):
                input_nodes.append(arg)

        for user in node.users.keys():
            if user not in partition.nodes:
                output_nodes.append(node)

    # Try to find all the dq nodes that feed into this module.
    for node in input_nodes:
        if node.target == T_DQuantPerTensor:
            dequant_nodes += [node]

    # Try to find all the q nodes that this module feeds into.
    for node in output_nodes:
        for user in node.users.keys():
            if user.target == T_QuantPerTensor:
                quant_nodes += [user]

    assert len(dequant_nodes) >= 1, "Dequant nodes missing in node list to be replaced."
    assert len(quant_nodes) >= 1, "Quant nodes missing in node list to be replaced."

    # After this, node_list contains all the nodes in the dq -> op -> q pattern
    # that we want to replace with a custom backend op.
    node_list = dequant_nodes + partition_nodes + quant_nodes

    submodule, call_module_node = create_submodule_from_nodes(
        graph_module, node_list, "to_be_replaced", skip_legalize_graph=True
    )

    # Update the replaced op so that we have all the latest args and kwargs.
    with graph_module.graph.inserting_before(call_module_node):
        replaced_op = graph_module.graph.call_function(
            replacement_op,
            call_module_node.args,
            kwargs=call_module_node.kwargs,
        )
        call_module_node.replace_all_uses_with(replaced_op)
        graph_module.graph.erase_node(call_module_node)
        replaced_op.meta = call_module_node.meta
    graph_module.recompile()

    return (replaced_op, dequant_nodes, quant_nodes)
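

# A hedged sketch of how this is typically driven from a quant fusion pass.
# get_source_partitions comes from torch.fx.passes.utils.source_matcher_utils;
# my_backend_op stands in for a backend-specific OpOverloadPacket and is an
# assumption for illustration:
#
#     from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
#
#     partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
#     for partition_list in partitions.values():
#         for partition in partition_list:
#             replaced_op, dq_nodes, q_nodes = replace_quantized_partition_with_op(
#                 graph_module, partition, my_backend_op
#             )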


def _assign_new_tag(
    tagged_exported_program: ExportedProgram,
    copied_nodes: Set[str],
):
    """
    Assign a new tag to the copied nodes.

    Before the pass:
    constant_0 (tag_10) ------------------> op_b (tag_10)
    constant_0_copy (tag_10) -------------> op_a (tag_11)

    After the pass:
    constant_0 (tag_10) ------------------> op_b (tag_10)
    constant_0_copy (tag_11) -------------> op_a (tag_11)

    """
    for node in tagged_exported_program.graph.nodes:
        if node.op == "placeholder":
            if node.name in copied_nodes:
                users_tag = set()
                for user in node.users:
                    users_tag.add(user.meta.get("delegation_tag", None))
                # Give the copied constant node the same tag as its users.
                if len(users_tag) == 1:
                    node.meta["delegation_tag"] = users_tag.pop()


def _maybe_duplicate_constant_nodes(
    tagged_exported_program: ExportedProgram,
    tag: str,
) -> None:
    """
    If a constant node is shared by nodes with different tags, like

    constant_0 ----> op_b (tag_10)
    |-------------> op_a (tag_11)

    then by default constant_0 is duplicated (to constant_0_1, constant_0_2, ...),
    unless the node is tagged with "no_copy":

    constant_0 ------------------> op_b (tag_10)
    constant_0_copy -------------> op_a (tag_11)

    Backends can decide how shared constant nodes should be handled: tag them
    with "no_copy" to error out instead of the default duplication behavior.
    """
    candidate_nodes = set()
    for node in tagged_exported_program.graph.nodes:
        if node.meta.get("delegation_tag", "") == tag:
            if node.op == "placeholder":
                for user in node.users:
                    users_tag = user.meta.get("delegation_tag", None)
                    if users_tag != tag:
                        # If the node is tagged with "no_copy", we don't
                        # duplicate it and raise an error instead.
                        if node.meta.get("no_copy", False):
                            raise RuntimeError(
                                f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
                            )
                        else:
                            candidate_nodes.add(node.name)
    copied_nodes = set()
    for candidate_node in candidate_nodes:
        # Both the tagged exported program and the owning program need to go
        # through the same duplication pass.
        copied_nodes = copied_nodes.union(
            duplicate_constant_node(tagged_exported_program, candidate_node)
        )
    candidate_node_with_copies = candidate_nodes.union(copied_nodes)
    _assign_new_tag(tagged_exported_program, candidate_node_with_copies)
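

# Illustrative sketch: a partitioner that wants to forbid duplication can tag
# shared constants with "no_copy" before calling this helper (the loop below is
# an assumption, not a prescribed flow):
#
#     for node in tagged_exported_program.graph.nodes:
#         if node.op == "placeholder" and is_param(tagged_exported_program, node):
#             node.meta["no_copy"] = True  # shared params now raise instead of copying
#     _maybe_duplicate_constant_nodes(tagged_exported_program, "tag_10")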


def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
    """
    Check whether the node is a getitem that follows an executorch_call_delegate
    node. These getitem nodes exist only to extract results from a delegate,
    because the inputs/outputs to delegates are flattened.
    """
    return (
        node.target == operator.getitem
        and len(node.args) == 2
        and node.args[0].target == executorch_call_delegate  # pyre-ignore
        and isinstance(node.args[1], int)
    )


def get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Returns a list of the non-lowered nodes in the graph module.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "call_function"
        and node.target != executorch_call_delegate
        and (not _get_item_from_executorch_call_delegate(node))
    ]


def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Returns the list of delegate (lowered module) nodes in the graph.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "get_attr" and node.name.startswith("lowered_module_")
    ]
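

# Example sketch combining the two helpers above to report partitioning
# coverage after lowering (`edge` is an assumed EdgeProgramManager returned by
# to_backend):
#
#     gm = edge.exported_program().graph_module
#     lowered = get_delegates(gm.graph)
#     not_lowered = get_non_lowered_nodes(gm.graph)
#     print(f"{len(lowered)} delegate(s), {len(not_lowered)} non-lowered node(s)")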


def print_delegated_graph(graph_module: torch.fx.GraphModule) -> None:
    """
    Print the formatted graph string.
    """
    print(format_delegated_graph(graph_module))


def format_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
    """
    Return the formatted string of the graph module, with each lowered_module
    expanded inline (both its backend id and its original graph). Example output:
    graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
        %arg2_1 : [num_users=2] = placeholder[target=arg2_1]
        %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
            backend_id: BackendWithCompilerDemo
            lowered graph():
                %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
                %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
                %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
                %aten_mm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {})
                %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default, %arg2_1), kwargs = {})
                return [aten_add_tensor]
        %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1, %arg1_1, %arg2_1), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
        %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%getitem, %arg0_1), kwargs = {})
        %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1]
            backend_id: BackendWithCompilerDemo
            lowered graph():
                %aten_sub_tensor : [num_users=1] = placeholder[target=aten_sub_tensor]
                %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
                %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
                %aten_mm_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_sub_tensor, %arg1_1), kwargs = {})
                %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default_1, %arg2_1), kwargs = {})
                return [aten_add_tensor_1]
        %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_sub_tensor, %arg1_1, %arg2_1), kwargs = {})
        %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {})
        return [getitem_1]
    """
    lowered_module_dict = {
        node.name: getattr(graph_module, node.name)
        for node in graph_module.graph.nodes
        if node.op == "get_attr" and node.name.startswith("lowered_module_")
    }
    indent = "  "
    graph_format_str = "graph():\n"
    for node in graph_module.graph.nodes:
        graph_format_str += f"{indent}{node.format_node()}\n"
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            lowered_module = lowered_module_dict[node.name]
            graph_format_str += f"{indent * 2}backend_id: {lowered_module.backend_id}\n"
            graph_format_str += f"{indent * 2}lowered graph():\n"
            for node_in_lowered_module in lowered_module.original_module.graph.nodes:
                graph_format_str += (
                    f"{indent * 3}{node_in_lowered_module.format_node()}\n"
                )
    return graph_format_str
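

# Typical usage is to inspect a program after partitioning, e.g. (assuming an
# EdgeProgramManager named `edge` produced by to_backend):
#
#     print(format_delegated_graph(edge.exported_program().graph_module))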


def tag_constant_data(edge_program: ExportedProgram) -> None:
    """
    Util function for partitioners. This function tags the const/param/buffer
    nodes whose users all belong to the same partition. It should be called
    after tagging all other nodes. Any const/param/buffer used as input to a
    subgraph is tagged with the same tag as that subgraph. If a
    const/param/buffer is used across different partitions, the underlying data
    would be owned by multiple delegates; in that case a message is logged and
    the duplication/no_copy handling is left to _maybe_duplicate_constant_nodes.
    """
    mutated_buffer = set()
    for node in edge_program.graph.nodes:
        if node.op == "placeholder" and (
            is_param(edge_program, node)
            or is_buffer(edge_program, node)
            or is_lifted_tensor_constant(edge_program, node)
        ):
            for node_user in node.users:
                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
                    logging.info(
                        "The buffer node is a mutated buffer node, which is not constant."
                    )
                    mutated_buffer.add(node)

    for node in edge_program.graph.nodes:
        # Go through const/param/buffer nodes; if all users of a node are
        # partitioned, then tag the node as well.
        if node.op == "placeholder" and (
            is_param(edge_program, node)
            or is_buffer(edge_program, node)
            or is_lifted_tensor_constant(edge_program, node)
        ):
            if node not in mutated_buffer:
                user_tags = set()
                for user in node.users:
                    user_tag = user.meta.get("delegation_tag", None)
                    if user_tag is not None:
                        user_tags.add(user_tag)
                if len(user_tags) > 1:
                    logging.info(
                        f"The data node is used across multiple partitions, including {user_tags}. "
                        "If the data is too large and it's not preferred to copy, please tag the "
                        "constant node like node.meta['no_copy'] = True and it won't be copied."
                    )
                # Tag the data node with the same tag as one of its users.
                if len(user_tags) > 0:
                    node.meta["delegation_tag"] = user_tags.pop()


def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
    """
    Util function for partitioners. This function tags the mutated buffer nodes
    whose users all belong to the same partition. It should be called after
    tagging all other nodes. Any buffer used as input to a subgraph is tagged
    with the same tag as that subgraph. If a buffer is used across different
    partitions, the underlying data would be owned by multiple delegates; in
    that case a message is logged and the node is tagged with one of its users'
    tags.
    """
    for node in edge_program.graph.nodes:
        # Determine whether this node is a mutated buffer.
        is_mutated_buffer_node = False
        if node.op == "placeholder" and is_buffer(edge_program, node):
            for node_user in node.users:
                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
                    is_mutated_buffer_node = True
                    break
        # This node is a mutated buffer, tag it.
        if is_mutated_buffer_node:
            user_tags = set()
            for user in node.users:
                user_tag = user.meta.get("delegation_tag", None)
                if user_tag is not None:
                    user_tags.add(user_tag)
            if len(user_tags) > 1:
                logging.info(
                    f"The data node is used across multiple partitions, including {user_tags}. "
                    "If the data is too large and it's not preferred to copy, please tag the "
                    "constant node like node.meta['no_copy'] = True and it won't be copied."
                )
            # Tag the data node with the same tag as one of its users.
            if len(user_tags) > 0:
                node.meta["delegation_tag"] = user_tags.pop()
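

# Sketch of where these tagging utilities fit in a partitioner's partition()
# method (a hypothetical flow; propose_partitions and delegation_spec are
# assumptions, and how partitions are proposed varies per backend):
#
#     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
#         partition_tags = {}
#         for i, proposed in enumerate(self.propose_partitions(exported_program)):
#             tag = f"tag{i}"
#             for node in proposed.nodes:
#                 node.meta["delegation_tag"] = tag
#             partition_tags[tag] = self.delegation_spec
#         # Tag const/param/buffer inputs and mutated buffers last, so they
#         # inherit the tags of the compute nodes that consume them.
#         tag_constant_data(exported_program)
#         tag_mutated_buffer(exported_program)
#         return PartitionResult(
#             tagged_exported_program=exported_program, partition_tags=partition_tags
#         )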


# TODO - style: use templated types
class DelegateMappingBuilder:
    """
    Profiling helper class for building delegate mappings.
    Delegate mappings are mappings from delegate debug identifiers to node
    debug handles. Specifically, they are used for logging within backend
    delegates.

    Args:
        generated_identifiers (bool, optional): Whether identifier keys are
            generated automatically. Defaults to False.
    """

    def __init__(self, generated_identifiers: bool = False):
        self._generated_identifiers = generated_identifiers

        # Note that the internal struct has a Set value, while the getter
        # function returns the values as a tuple.
        self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
            defaultdict(set)
        )
        self._next_index: int = 0

    def get_delegate_mapping(
        self,
    ) -> Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
        """
        Returns:
           Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
                A map from delegate debug identifier to a list of debug handles.
                The keys (identifiers) are either integers or strings.
                The values are sorted tuples of integer debug handles.
        """
        # pyre-ignore Warning between Union[Dict[K, V], Dict[K2, V]] vs Dict[Union[K, K2], V]
        return {k: tuple(sorted(v)) for k, v in self._debug_handle_map.items()}

    def insert_delegate_mapping_entry(
        self,
        nodes: Optional[Union[Node, List[Node]]] = None,
        handles: Optional[Union[int, List[Optional[int]]]] = None,
        identifier: Optional[Union[int, str]] = None,
    ) -> Union[int, str]:
        """
        Add a new delegate mapping entry.

        If self._generated_identifiers = False:
            - A new identifier must be provided, else an exception is thrown

        If self._generated_identifiers = True:
            - New identifiers are generated incrementally, 0-indexed
            - Identifiers cannot be manually provided, else an exception is thrown

        Args:
            nodes (Union[Node, List[Node]]): A (list of) Node(s)
            handles (Union[int, List[Optional[int]]]): A (list of) debug handle(s)
            identifier (Optional[Union[int, str]]):
                Debug identifier corresponding to the Node(s)

        Note: Exactly one of nodes and handles must be provided
        Note: If a debug handle is missing or None, it is skipped

        Returns:
            Union[int, str]:
                Delegate debug identifier inserted
        """

        # Check for manual addition of an identifier (with generated identifiers enabled)
        if self._generated_identifiers and identifier is not None:
            raise Exception(
                f"Builders using generated identifiers can't manually add identifiers: {identifier}. Failed to add or update entry."
            )

        if identifier is not None and identifier in self._debug_handle_map:
            raise Exception(
                "This delegate debug identifier was already inserted. Duplicate delegate debug identifiers are not allowed."
            )

        # Check that exactly one of nodes and handles is populated
        if not ((nodes is not None) ^ (handles is not None)):
            raise Exception(
                "Exactly one of nodes or handles must be provided. Either both were provided or neither was provided. Failed to add or update entry."
            )

        # Resolve identifier
        if identifier is None:
            if self._generated_identifiers:
                identifier = self._next_index
                self._next_index += 1
            else:
                raise Exception(
                    "No identifier provided. Failed to add or update entry."
                )

        # Collect debug handles
        if nodes is not None:
            new_debug_handles = {
                node.meta.get("debug_handle")
                for node in (nodes if isinstance(nodes, list) else [nodes])
            }
        else:
            new_debug_handles = (
                handles if isinstance(handles, (tuple, list)) else [handles]
            )

        # Filter out empty debug handles
        filtered_debug_handles = {
            handle for handle in new_debug_handles if handle is not None
        }
        if len(filtered_debug_handles) == 0:
            raise Exception("No valid debug handles found. Failed to add entry.")

        # pyre-ignore Warning from Union[int, str] keys
        self._debug_handle_map[identifier] = filtered_debug_handles
        return identifier
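

# Minimal usage sketch, assuming nodes that carry "debug_handle" metadata:
#
#     builder = DelegateMappingBuilder(generated_identifiers=True)
#     for node in graph_module.graph.nodes:
#         if node.op == "call_function":
#             builder.insert_delegate_mapping_entry(nodes=node)
#     mapping = builder.get_delegate_mapping()  # e.g. {0: (10,), 1: (11, 12)}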


class WhyNoPartition:
    """
    Simple helper class for partitioners to log why a node was not lowered.

    Example usage:

        # In your backend partitioner file(s)
        why = WhyNoPartition(logger=your_backend_logger)

        # hypothetical function that checks if a node can be lowered
        if not can_be_lowered(node):
            why(node, "This node was not lowered because ...")
    """

    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.node: Optional[torch.fx.Node] = None
        self.reason: str = ""

    def __call__(self, node: torch.fx.Node, reason: str) -> None:
        self.node = node
        self.reason = reason
        self.logger.debug(self)

    def __str__(self) -> str:
        return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}."
561