# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-unsafe


# This file contains the passes that simplify the arguments of an op.

import sys
from typing import Optional

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class SimplifySliceOpPass(ExportPass):
    """
    Simplify the start and end indices of slice and slice_scatter ops.

    A negative ``dim`` is made non-negative. Negative or out-of-range
    ``start``/``end`` values are normalized to indices in ``[0, size(dim)]``
    with ``start <= end``. A slice that selects no elements is replaced by an
    empty ``full`` tensor. A slice_scatter that writes no elements is replaced
    by its input.
    """

    def adjust_slice_range(
        self,
        length: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> tuple[int, int]:
        """
        Normalize slice bounds for a dimension of size ``length``.

        Args:
            length: Size of the sliced dimension.
            start: Start index. ``None`` means 0; negative values count from
                the end of the dimension.
            end: Exclusive end index. ``None`` means "to the end"; negative
                values count from the end of the dimension.
            step: Slice step. Not used by the computation; accepted so callers
                can pass the full slice specification.

        Returns:
            ``(start, end)`` satisfying ``0 <= start <= end <= length``.
        """
        start_val = 0 if start is None else start
        # sys.maxsize (2**63 - 1 on 64-bit builds) stands in for "no end"; it
        # is clamped to `length` below.
        end_val = sys.maxsize if end is None else end

        # Negative indices count from the end of the dimension.
        if start_val < 0:
            start_val += length
        if end_val < 0:
            end_val += length

        # Clamp start into [0, length], then end into [start, length] so the
        # resulting range is never inverted.
        start_val = min(max(start_val, 0), length)
        end_val = min(max(end_val, start_val), length)
        return (start_val, end_val)

    def call_operator(self, op, args, kwargs, meta):
        # Only slice_copy and slice_scatter are rewritten by this pass.
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.slice_scatter.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # slice_scatter has an extra `src` argument at index 1, which shifts
        # dim/start/end/step one position to the right.
        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
        offset = 1 if slice_scatter else 0

        def _arg(index, default):
            # Positional argument `index` of the slice signature (after the
            # slice_scatter shift), or `default` when it was not supplied.
            pos = index + offset
            return args[pos] if len(args) > pos else default

        # Extract the tensor to be sliced and the (non-negative) slicing dim.
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        dim = _arg(1, 0)
        if dim < 0:
            dim += in_tensor.dim()
        length = in_tensor.size(dim)

        step = _arg(4, 1)
        start_val, end_val = self.adjust_slice_range(
            length, _arg(2, None), _arg(3, None), step
        )

        if start_val >= end_val:
            if slice_scatter:
                # Nothing is scattered, so the result is the input unchanged.
                return args[0]
            # The slice selects no elements. Emit an empty tensor with the
            # input's shape, except for size 0 along `dim`. Keep every
            # dimension, including ones that are already zero-sized, so the
            # rank and the position of `dim` are preserved.
            empty_shape = list(in_tensor.shape)
            empty_shape[dim] = 0
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": in_tensor.dtype},
                meta,
            )

        # Rebuild the args with the normalized dim/start/end. Keep `input`,
        # and also `src` for slice_scatter.
        new_args = (*args[: 1 + offset], dim, start_val, end_val, step)
        return super().call_operator(op, new_args, kwargs, meta)


# Groups the passes in this file that simplify an op's arguments.
class CadenceSimplifyOpsInGraph:
    passes = [
        SimplifySliceOpPass,
    ]