xref: /aosp_15_r20/external/executorch/backends/cadence/aot/simplify_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3# pyre-unsafe
4
5
6# This file contains all the functions that simplify args of an op
7
8import sys
9from typing import Optional
10
11from executorch.backends.cadence.aot.pass_utils import (
12    CadencePassAttribute,
13    register_cadence_pass,
14)
15
16from executorch.exir.dialects._ops import ops as exir_ops
17from executorch.exir.pass_base import ExportPass, ProxyValue
18
19
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class SimplifySliceOpPass(ExportPass):
    """
    Simplify the start and end indices of slice_copy and slice_scatter ops.

    Each matching op is rewritten as follows:
      - A negative `dim` is made non-negative.
      - Start/end indices (negative, None, or out of range) are normalized
        to the range [0, length] of the sliced dimension.

    When the normalized slice is empty (start >= end):
      - A slice_copy is replaced by a zero-filled `full` op whose shape equals
        the input shape with the sliced dimension set to 0.
      - A slice_scatter is replaced by its input tensor (args[0]).
    """

    def adjust_slice_range(
        self,
        length: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> tuple[int, int]:
        """
        Normalize slice bounds to the range [0, length].

        Args:
            length: Size of the tensor along the sliced dimension.
            start: Start index. None means 0; negative values count from the end.
            end: Exclusive end index. None means sys.maxsize; negative values
                count from the end.
            step: Slice step. It does not affect the computed bounds and is
                accepted only so the signature mirrors the slice op arguments.

        Returns:
            A (start_val, end_val) pair with 0 <= start_val <= end_val <= length.
        """
        start_val = start if start is not None else 0
        end_val = end if end is not None else sys.maxsize  # 2^63 - 1

        # Negative indices count from the end of the dimension.
        if start_val < 0:
            start_val += length
        if end_val < 0:
            end_val += length

        # Clamp start into [0, length], then clamp end into [start_val, length].
        # The second clamp makes an inverted range collapse to an empty one.
        start_val = min(max(start_val, 0), length)
        end_val = min(max(end_val, start_val), length)

        return (start_val, end_val)

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite slice_copy / slice_scatter with normalized dim/start/end args.

        All other ops are forwarded unchanged to ExportPass.call_operator.
        """
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.slice_scatter.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # slice_scatter has an extra `src` argument at index 1, which shifts
        # every following positional argument (dim, start, end, step) by one.
        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
        offset = 1 if slice_scatter else 0

        def get_arg(index, default):
            # Read the positional slice argument at `index` (counted as for
            # slice_copy), or `default` when it was not passed.
            pos = index + offset
            return args[pos] if len(args) > pos else default

        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]

        dim = get_arg(1, 0)
        if dim < 0:
            dim += in_tensor.dim()
        length = in_tensor.size(dim)

        step = get_arg(4, 1)
        (start_val, end_val) = self.adjust_slice_range(
            length, get_arg(2, None), get_arg(3, None), step
        )

        if start_val >= end_val:
            # Empty slice: slice_scatter writes nothing, so it returns its input.
            if slice_scatter:
                return args[0]
            # slice_copy yields an empty tensor. Keep every dimension, including
            # ones that are already zero-sized, so that `dim` still indexes the
            # sliced dimension and the output rank matches the input rank.
            empty_shape = list(in_tensor.shape)
            empty_shape[dim] = 0
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": in_tensor.dtype},
                meta,
            )

        new_args = (
            (args[0],)
            + ((args[1],) if slice_scatter else ())
            + (dim, start_val, end_val, step)
        )
        return super().call_operator(op, new_args, kwargs, meta)
106
107
class CadenceSimplifyOpsInGraph:
    """
    Group of the passes that simplify the arguments of ops in a graph.

    `passes` holds the pass classes themselves (not instances).
    """

    passes = [SimplifySliceOpPass]
113