# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Dict

import torch

from executorch.exir.pass_base import ExportPass

from torch._ops import OpOverload


# Lookup table used by ReplaceBrokenOpsWithFunctionalOpsPass: each key is an
# aten op that is not functionalized properly, and each value is the `*_copy`
# op it is rewritten to. Note that `_unsafe_view` is also routed to
# `view_copy`.
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
    torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
    torch.ops.aten.t.default: torch.ops.aten.t_copy.default,
    torch.ops.aten.view.default: torch.ops.aten.view_copy.default,
    torch.ops.aten.expand.default: torch.ops.aten.expand_copy.default,
    torch.ops.aten.permute.default: torch.ops.aten.permute_copy.default,
    torch.ops.aten.squeeze.default: torch.ops.aten.squeeze_copy.default,
    torch.ops.aten.unsqueeze.default: torch.ops.aten.unsqueeze_copy.default,
    torch.ops.aten.slice.Tensor: torch.ops.aten.slice_copy.Tensor,
}


class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
    """
    Replace non-functionalized aten operators with their functional variants.

    Our backend expects pure functions, but some operators are not
    functionalized properly. This pass rewrites every operator listed in
    ``_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS`` to its ``*_copy`` counterpart
    and leaves all other operators untouched.

    TODO: this can be refactored into a general OpReplacementPass.
    """

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        """
        Forward ``op``, or its functional replacement, to ``ExportPass``.

        Args:
            op: The operator being processed by the pass.
            args: Positional arguments for the operator, forwarded unchanged.
            kwargs: Keyword arguments for the operator, forwarded unchanged.
            meta: Node metadata, forwarded unchanged.

        Returns:
            Whatever ``ExportPass.call_operator`` returns for the (possibly
            substituted) operator.
        """
        # Operators absent from the table map to themselves, so they pass
        # through unchanged; a single lookup replaces `in` + `[]`.
        replacement = _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS.get(op, op)
        return super().call_operator(replacement, args, kwargs, meta)