# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# For BI case
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)

# For MI case
edge_softmax = (
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten._log_softmax.default,
)

log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)

# Every op this pass rewrites. Built once at import time so that
# call_operator does not concatenate the two tuples for each visited node.
_decomposable_softmax = torch_softmax + edge_softmax


def get_logsoftmax_ops(op) -> tuple:
    """
    Return the primitive ops used to decompose a (log) softmax op.

    The result is the 5-tuple (log_op, exp_op, sum_op, reciprocal_op, mul_op).
    The ops are taken from exir_ops.edge.aten when `op` is one of
    `edge_softmax`, and from torch.ops.aten when `op` is one of
    `torch_softmax`.

    Raises:
        RuntimeError: if `op` is in neither `edge_softmax` nor `torch_softmax`.
    """
    if op in edge_softmax:
        return (
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    if op in torch_softmax:
        return (
            torch.ops.aten.log.default,
            torch.ops.aten.exp.default,
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.reciprocal.default,
            torch.ops.aten.mul.Tensor,
        )
    raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")


class DecomposeSoftmaxesPass(ExportPass):
    """
    This pass decomposes log softmax or softmax into more primitive ops.

    Example:
        %op1 = exp(x)
        %op2 = sum(%op1, dim)
        %op3 = reciprocal(%op2)
        %op4 = mul(%op1, %op3)
        (in logsoftmax case: %op5 = log(%op4))

    Only the first two positional arguments of the original op (the input and
    the dim) are used; any further arguments are not forwarded to the emitted
    ops.

    NOTE(review): exp is applied directly to the input, without subtracting
    the per-dim maximum first, so this is the textbook form of softmax rather
    than the max-shifted one.
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a softmax/log_softmax call with its primitive decomposition.

        Ops that are not in `torch_softmax` or `edge_softmax` are passed
        through to ExportPass.call_operator unchanged. For the handled ops,
        the value returned is the result of the last emitted op (mul for
        softmax, log for log softmax).
        """
        if op not in _decomposable_softmax:
            return super().call_operator(op, args, kwargs, meta)

        log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op)

        _input = args[0]
        # sum.dim_IntList takes a list of dims, so wrap the single dim.
        dim = [args[1]]

        op1 = super().call_operator(exp_op, (_input,), {}, meta)
        # The third argument is keepdim=True: the reduced dim is kept so the
        # reciprocal of the sum can be multiplied with op1 below.
        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
        if op in log_softmax:
            # log softmax is emitted as log applied to the softmax result.
            op4 = super().call_operator(log_op, (op4,), {}, meta)
        return op4