# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
"""
Pass-through operations for MaskedTensor.

The aten ops listed in ``PASSTHROUGH_FNS`` need no masking-specific logic.
An op such as ``select`` or ``stack`` is simply run once on the data and once
on the mask of the MaskedTensor arguments. The two results are then wrapped
back into a new MaskedTensor.
"""

import torch

from .core import _map_mt_args_kwargs, _wrap_result


__all__ = []  # type: ignore[var-annotated]


# Ops that are forwarded unchanged to both the data and the mask.
# Kept as a list (rather than a set) so its order stays stable for any
# caller that iterates over it.
PASSTHROUGH_FNS = [
    torch.ops.aten.select,
    torch.ops.aten.transpose,
    torch.ops.aten.split,
    torch.ops.aten.t,
    torch.ops.aten.slice,
    torch.ops.aten.slice_backward,
    torch.ops.aten.select_backward,
    torch.ops.aten.index,
    torch.ops.aten.expand,
    torch.ops.aten.view,
    torch.ops.aten._unsafe_view,
    torch.ops.aten._reshape_alias,
    torch.ops.aten.cat,
    torch.ops.aten.unsqueeze,
    torch.ops.aten.unfold,
    torch.ops.aten.unfold_backward,
    torch.ops.aten.im2col,
    torch.ops.aten.col2im,
    torch.ops.aten.stack,
]


def _is_pass_through_fn(fn):
    """Report whether ``fn`` is one of the ops in ``PASSTHROUGH_FNS``.

    Args:
        fn: The op to look up.

    Returns:
        ``True`` if ``fn`` appears in ``PASSTHROUGH_FNS``, otherwise ``False``.
    """
    listed = fn in PASSTHROUGH_FNS
    return listed


def _apply_pass_through_fn(fn, *args, **kwargs):
    """Run ``fn`` on the data and on the mask separately, then rewrap.

    ``fn`` is called twice. In the first call, each MaskedTensor argument is
    replaced by its ``get_data()``. In the second, each is replaced by its
    ``get_mask()``. The two outputs are combined with ``_wrap_result``.

    Args:
        fn: The op to forward. Presumably an entry of ``PASSTHROUGH_FNS``;
            this function does not check that.
        *args: Positional arguments for ``fn``, possibly containing
            MaskedTensors.
        **kwargs: Keyword arguments for ``fn``, possibly containing
            MaskedTensors.

    Returns:
        Whatever ``_wrap_result`` builds from the (data result, mask result)
        pair.
    """

    def _call_on(component):
        # Let ``_map_mt_args_kwargs`` apply ``component`` to the MaskedTensor
        # arguments (presumably leaving other values alone), then invoke fn.
        mapped_args, mapped_kwargs = _map_mt_args_kwargs(args, kwargs, component)
        return fn(*mapped_args, **mapped_kwargs)

    # The data call happens before the mask call, as in the original ordering.
    data_out = _call_on(lambda mt: mt.get_data())
    mask_out = _call_on(lambda mt: mt.get_mask())
    return _wrap_result(data_out, mask_out)