xref: /aosp_15_r20/external/executorch/exir/passes/replace_broken_ops_with_function_ops_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8from typing import Dict
9
10import torch
11
12from executorch.exir.pass_base import ExportPass
13
14from torch._ops import OpOverload
15
16
# Replacement table consulted by ReplaceBrokenOpsWithFunctionalOpsPass:
# each key is an ATen overload that the pass rewrites, and each value is the
# `*_copy` overload emitted in its place. Note that both `_unsafe_view` and
# `view` are mapped to the same `view_copy` overload.
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
    torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
    torch.ops.aten.t.default: torch.ops.aten.t_copy.default,
    torch.ops.aten.view.default: torch.ops.aten.view_copy.default,
    torch.ops.aten.expand.default: torch.ops.aten.expand_copy.default,
    torch.ops.aten.permute.default: torch.ops.aten.permute_copy.default,
    torch.ops.aten.squeeze.default: torch.ops.aten.squeeze_copy.default,
    torch.ops.aten.unsqueeze.default: torch.ops.aten.unsqueeze_copy.default,
    torch.ops.aten.slice.Tensor: torch.ops.aten.slice_copy.Tensor,
}
27
28
class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
    """
    Swap selected non-functional operators for their functional variants.

    TODO: Our backend expects pure functions. However, some operators
    are not functionalized properly. Until that is addressed, this pass
    rewrites every operator listed in
    ``_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS`` to the replacement it maps
    to; operators not in that table are forwarded unchanged.

    TODO: this can be refactored into a general OpReplacementPass
    """

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        # An operator missing from the table maps to itself, so one lookup
        # handles both the rewritten and the pass-through case.
        replacement = _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS.get(op, op)
        return super().call_operator(replacement, args, kwargs, meta)
45