xref: /aosp_15_r20/external/executorch/backends/arm/_passes/decompose_softmaxes_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass
12
# Softmax variants in the aten dialect, seen in the BI (quantized) case.
torch_softmax = (
    torch.ops.aten.softmax.int,
    torch.ops.aten.log_softmax.int,
)

# Softmax variants in the edge dialect, seen in the MI (float) case.
edge_softmax = (
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten._log_softmax.default,
)

# Log-softmax variants from both dialects; used to decide whether the
# decomposition needs a trailing log op.
log_softmax = (
    torch.ops.aten.log_softmax.int,
    exir_ops.edge.aten._log_softmax.default,
)
23
24
def get_logsoftmax_ops(op) -> tuple:
    """
    Return the (log_op, exp_op, sum_op, reciprocal_op, mul_op) tuple used to
    decompose the given (log)softmax op.

    The returned ops come from exir_ops.edge.aten when `op` is an edge-dialect
    softmax, or from torch.ops.aten when `op` is an aten-dialect softmax.

    Raises:
        RuntimeError: if `op` is not a recognized softmax/log_softmax op.
    """
    if op in edge_softmax:
        return (
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    if op in torch_softmax:
        return (
            torch.ops.aten.log.default,
            torch.ops.aten.exp.default,
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.reciprocal.default,
            torch.ops.aten.mul.Tensor,
        )
    raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
47
48
class DecomposeSoftmaxesPass(ExportPass):
    """
    Decomposes softmax / log_softmax into more primitive ops.

    softmax(x, dim) is rewritten as:
        exp_x  = exp(x)
        denom  = sum(exp_x, dim, keepdim=True)
        inv    = reciprocal(denom)
        result = mul(exp_x, inv)
    For log_softmax an extra log(result) is appended.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Pass through any op that is not a softmax variant we handle.
        if op not in torch_softmax + edge_softmax:
            return super().call_operator(op, args, kwargs, meta)

        # Pick the decomposition ops matching op's dialect (aten vs edge).
        log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op)

        x = args[0]
        dims = [args[1]]  # sum.dim_IntList expects a list of dims

        exp_x = super().call_operator(exp_op, (x,), {}, meta)
        # keepdim=True so the reciprocal broadcasts back against exp_x.
        denom = super().call_operator(sum_op, (exp_x, dims, True), {}, meta)
        inv = super().call_operator(reciprocal_op, (denom,), {}, meta)
        result = super().call_operator(mul_op, (exp_x, inv), {}, meta)

        if op in log_softmax:
            result = super().call_operator(log_op, (result,), {}, meta)
        return result
77