xref: /aosp_15_r20/external/pytorch/torch/masked/maskedtensor/passthrough.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3"""
These are functions that should simply be applied to both the mask and the data.
Take select or stack as an example: such an operation can be applied to
both the mask and the data of a MaskedTensor, and the two results wrapped
into a new MaskedTensor.
8"""
9
10import torch
11
12from .core import _map_mt_args_kwargs, _wrap_result
13
14
__all__ = []  # type: ignore[var-annotated]


# aten operators that are applied unchanged to both the data tensor and the
# mask tensor of a MaskedTensor (see ``_apply_pass_through_fn``). Each entry
# is the ``torch.ops.aten.<name>`` object itself; order is preserved exactly
# as listed below. The comprehension variable does not leak into the module.
PASSTHROUGH_FNS = [
    getattr(torch.ops.aten, op_name)
    for op_name in (
        "select",
        "transpose",
        "split",
        "t",
        "slice",
        "slice_backward",
        "select_backward",
        "index",
        "expand",
        "view",
        "_unsafe_view",
        "_reshape_alias",
        "cat",
        "unsqueeze",
        "unfold",
        "unfold_backward",
        "im2col",
        "col2im",
        "stack",
    )
]
39
40
def _is_pass_through_fn(fn):
    """Report whether ``fn`` is one of the registered pass-through ops.

    Args:
        fn: the operator to look up.

    Returns:
        ``True`` when ``fn`` is (or compares equal to) an entry of
        ``PASSTHROUGH_FNS``, otherwise ``False``.
    """
    # NOTE(review): the entries are the ``torch.ops.aten.<name>`` objects
    # themselves; callers holding a specific overload presumably need to
    # compare its packet instead -- confirm at the call sites.
    # Written out exactly as list membership is defined: identity first,
    # then equality.
    return any(fn is entry or fn == entry for entry in PASSTHROUGH_FNS)
43
44
def _apply_pass_through_fn(fn, *args, **kwargs):
    """Apply ``fn`` once to the data side and once to the mask side of its inputs.

    ``_map_mt_args_kwargs`` is given an extractor (``get_data`` and then
    ``get_mask``). Presumably it applies the extractor to every MaskedTensor
    found in ``args``/``kwargs`` and leaves other values untouched; confirm
    in ``.core``. ``fn`` is called on each rewritten set of arguments, and
    the two outputs are combined by ``_wrap_result``.

    Args:
        fn: the operator to apply (presumably an entry of ``PASSTHROUGH_FNS``).
        *args: positional arguments for ``fn``; may contain MaskedTensors.
        **kwargs: keyword arguments for ``fn``; may contain MaskedTensors.

    Returns:
        Whatever ``_wrap_result`` builds from the data result and the mask
        result (presumably a MaskedTensor, or a collection of them for ops
        such as ``split`` -- TODO confirm against ``_wrap_result``).
    """

    def _call_on(extract):
        # Rewrite the inputs with ``extract``, then invoke ``fn`` on them.
        call_args, call_kwargs = _map_mt_args_kwargs(args, kwargs, extract)
        return fn(*call_args, **call_kwargs)

    # The data side is evaluated before the mask side.
    data_out = _call_on(lambda mt: mt.get_data())
    mask_out = _call_on(lambda mt: mt.get_mask())
    return _wrap_result(data_out, mask_out)
50    return _wrap_result(result_data, result_mask)
51