xref: /aosp_15_r20/external/pytorch/torch/masked/maskedtensor/binary.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3
4import torch
5
6from .core import (
7    _map_mt_args_kwargs,
8    _masks_match,
9    _tensors_match,
10    _wrap_result,
11    is_masked_tensor,
12)
13
14
# This module exports no public names: ``from <module> import *`` pulls in
# nothing, and every helper below is reached through explicit imports.
__all__ = []  # type: ignore[var-annotated]
16
# Names of the ``torch.ops.aten`` binary ops that get a MaskedTensor
# implementation. The order is kept stable because the dispatch tables
# further down are built by iterating this list.
BINARY_NAMES = [
    # arithmetic / bitwise
    "add", "atan2", "arctan2",
    "bitwise_and", "bitwise_or", "bitwise_xor",
    "bitwise_left_shift", "bitwise_right_shift",
    "div", "divide", "floor_divide", "fmod",
    "logaddexp", "logaddexp2",
    "mul", "multiply",
    "nextafter", "remainder",
    "sub", "subtract", "true_divide",
    # comparisons
    "eq", "ne", "le", "ge",
    "greater", "greater_equal", "gt",
    "less_equal", "lt", "less",
    # element-wise extrema
    "maximum", "minimum", "fmax", "fmin",
    # comparison alias listed last in the original ordering
    "not_equal",
]
55
# In-place counterparts (trailing underscore) of the binary ops above, minus
# the names listed in the exclusion set below (presumably because no in-place
# aten op exists for them -- TODO confirm).
#
# The list is derived by filtering BINARY_NAMES in order rather than by
# iterating a ``set`` difference: set iteration order for strings depends on
# hash randomization, which made the order of this list (and of everything
# built from it) vary from one interpreter run to the next.
INPLACE_BINARY_NAMES = [
    name + "_"
    for name in BINARY_NAMES
    if name
    not in {
        "logaddexp",
        "logaddexp2",
        # "equal" is not part of BINARY_NAMES; kept so the exclusion set is
        # unchanged from the original definition.
        "equal",
        "fmin",
        "minimum",
        "maximum",
        "fmax",
    }
]
73
74
def _get_at_least_one_mask(a, b):
    """Return the mask shared by the operands ``a`` and ``b``.

    The mask is taken from ``a`` when ``a`` is a MaskedTensor, otherwise
    from ``b``.

    Raises:
        TypeError: if neither operand is a MaskedTensor.
        ValueError: if ``_masks_match(a, b)`` reports a mismatch.
    """
    lhs_is_masked = is_masked_tensor(a)
    if not (lhs_is_masked or is_masked_tensor(b)):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    mask_owner = a if lhs_is_masked else b
    return mask_owner.get_mask()
83
84
def _binary_helper(fn, args, kwargs, inplace):
    """Run the binary op ``fn`` on the underlying data of its operands.

    Args:
        fn: callable invoked on the unwrapped data arguments.
        args: positional arguments of the op; ``args[0]`` and ``args[1]`` are
            the lhs and rhs. Any further arguments must not be tensors.
        kwargs: keyword arguments of the op; must be empty.
        inplace: if true, the result data and the lhs mask are stored back
            into ``args[0]`` (via ``_set_data_mask``) and ``args[0]`` is
            returned; otherwise a new wrapped result is returned.

    Returns:
        ``args[0]`` when ``inplace`` is true, else ``_wrap_result(data, mask)``.

    Raises:
        ValueError: if ``kwargs`` is non-empty, the input masks do not match,
            or sparse operands have mismatching indices / sizes.
        TypeError: if any argument after the rhs is a tensor.
    """
    if kwargs:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError(
                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
            )

    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    # Two parallel views of the arguments: one holding the data, one holding
    # the masks. The mapped kwargs are unused because kwargs is empty here.
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())

    args0_layout = data_args[0].layout
    # The rhs is only inspected for its layout when it is tensor-like; the
    # guard keeps ``.layout`` from being read on anything else.
    same_layout = (
        torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
    ) and (args0_layout == data_args[1].layout)

    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )

            data_args[1] = data_args[1].values()

        # Operate on the values only, then rebuild a COO tensor with the
        # original indices and size.
        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)

    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (
                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                and _tensors_match(
                    data_args[0].col_indices(), data_args[1].col_indices()
                )
            ):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )

            data_args[1] = data_args[1].values()

        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        # Capture the size before replacing the lhs with its values. Without
        # an explicit size, sparse_csr_tensor infers the smallest shape that
        # fits the indices, which drops trailing columns that hold no
        # specified elements and leaves the result smaller than the input.
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_csr_tensor(crow, col, v, size=size)

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]

    result_mask = _get_at_least_one_mask(*args[:2])
    # sparse tensors don't have strides so we can only expand if the layout is strided
    if args0_layout == torch.strided:
        result_mask = result_mask.expand_as(result_data)
    return _wrap_result(result_data, result_mask)
158
159
160def _torch_binary(fn_name):
161    fn = getattr(torch.ops.aten, fn_name)
162
163    def binary_fn(*args, **kwargs):
164        return _binary_helper(fn, args, kwargs, inplace=False)
165
166    return binary_fn
167
168
169def _torch_inplace_binary(fn_name):
170    fn = getattr(torch.ops.aten, fn_name)
171
172    def binary_fn(*args, **kwargs):
173        return _binary_helper(fn, args, kwargs, inplace=True)
174
175    return binary_fn
176
177
# Dispatch tables keyed by the aten op object, mapping to the handler built
# for that op name.
NATIVE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_binary(op_name)
    for op_name in BINARY_NAMES
}
NATIVE_INPLACE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_binary(op_name)
    for op_name in INPLACE_BINARY_NAMES
}

# Plain lists of the supported ops, in table insertion order.
NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP)
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP)
188
189
def _is_native_binary(fn):
    """Return True if ``fn`` is a supported binary op, in-place or not."""
    return any(
        fn in supported
        for supported in (NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS)
    )
192
193
def _apply_native_binary(fn, *args, **kwargs):
    """Call the handler registered for ``fn`` with the given arguments.

    The out-of-place table is consulted first, then the in-place one.
    Returns ``NotImplemented`` when ``fn`` is in neither.
    """
    lookup_order = (
        (NATIVE_BINARY_FNS, NATIVE_BINARY_MAP),
        (NATIVE_INPLACE_BINARY_FNS, NATIVE_INPLACE_BINARY_MAP),
    )
    for supported, handlers in lookup_order:
        if fn in supported:
            return handlers[fn](*args, **kwargs)
    return NotImplemented
200