xref: /aosp_15_r20/external/pytorch/torch/masked/maskedtensor/_ops_refs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3
4from functools import partial
5from typing import Any, Callable, Dict, TYPE_CHECKING
6
7import torch
8
9from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
10from .core import (
11    _get_data,
12    _masks_match,
13    _maybe_get_mask,
14    is_masked_tensor,
15    MaskedTensor,
16)
17from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
18from .reductions import (
19    _apply_reduction,
20    NATIVE_REDUCE_FNS,
21    TENSOR_REDUCE_FNS,
22    TORCH_REDUCE_FNS,
23)
24from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
25
26
27if TYPE_CHECKING:
28    from torch._ops import OpOverload
29
30
31__all__ = []  # type: ignore[var-annotated]
32
33
34def _check_args_kwargs_length(
35    args, kwargs, error_prefix, len_args=None, len_kwargs=None
36):
37    if len_args is not None and len_args != len(args):
38        raise ValueError(
39            f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
40        )
41    if len_kwargs is not None and len_kwargs != len(kwargs):
42        raise ValueError(
43            f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
44        )
45
46
class _MaskedContiguous(torch.autograd.Function):
    """Autograd-transparent ``contiguous()`` for MaskedTensor.

    Makes both the data and the mask contiguous; the gradient passes through
    unchanged since the values themselves are not modified.
    """

    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")

        # Already contiguous: return the input as-is.
        if input.is_contiguous():
            return input

        contiguous_data = input.get_data().contiguous()
        contiguous_mask = input.get_mask().contiguous()
        return MaskedTensor(contiguous_data, contiguous_mask)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity w.r.t. values, so forward the gradient unchanged.
        return grad_output
64
65
class _MaskedToDense(torch.autograd.Function):
    """Autograd-aware conversion of a MaskedTensor to the strided (dense) layout."""

    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")

        # Already dense: nothing to do.
        # NOTE(review): this early return leaves ctx.layout unset; backward on
        # this path would fail reading ctx.layout — confirm this path is never
        # differentiated through.
        if input.layout == torch.strided:
            return input

        # Remember the original layout so backward can convert the gradient back.
        ctx.layout = input.layout
        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.to_dense(), mask.to_dense())

    @staticmethod
    def backward(ctx, grad_output):
        layout = ctx.layout

        if layout == torch.sparse_coo:
            return grad_output.to_sparse_coo()
        elif layout == torch.sparse_csr:
            return grad_output.to_sparse_csr()
        elif layout == torch.strided:
            return grad_output.to_dense()
        # Fix: the original passed (message, layout) as two ValueError args,
        # producing a tuple-formatted message; format the layout into it instead.
        raise ValueError(f"to_dense: Unsupported input layout: {layout}")
92
93
class _MaskedToSparse(torch.autograd.Function):
    """Autograd-aware conversion of a MaskedTensor to the sparse COO layout."""

    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")

        # Following the convention from sparse tensors: to_sparse always
        # means conversion to sparse_coo, so a COO input is returned as-is.
        if input.layout == torch.sparse_coo:
            return input

        sparse_mask = input.get_mask().to_sparse_coo().coalesce()
        sparse_data = input.get_data().sparse_mask(sparse_mask)
        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of the layout conversion is just the densified gradient.
        return grad_output.to_dense()
114
115
class _MaskedToSparseCsr(torch.autograd.Function):
    """Autograd-aware conversion of a 2D MaskedTensor to the sparse CSR layout."""

    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")

        # CSR only supports matrices.
        if input._masked_data.ndim != 2:
            raise ValueError(
                f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
            )

        if input.layout == torch.sparse_csr:
            return input

        csr_mask = input.get_mask().to_sparse_csr()
        csr_data = input.get_data().sparse_mask(csr_mask)
        return MaskedTensor(csr_data, csr_mask)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of the layout conversion is just the densified gradient.
        return grad_output.to_dense()
140
141
142class _MaskedWhere(torch.autograd.Function):
143    @staticmethod
144    def forward(ctx, cond, self, other):
145        ctx.mark_non_differentiable(cond)
146        ctx.save_for_backward(cond)
147        return torch.ops.aten.where(cond, self, other)
148
149    @staticmethod
150    def backward(ctx, grad_output):
151        (cond,) = ctx.saved_tensors
152
153        def masked_out_like(mt):
154            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())
155
156        return (
157            None,
158            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
159            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
160        )
161
162
# Maps a torch function to its MaskedTensor __torch_function__ handler.
_MASKEDTENSOR_FUNCTION_TABLE = {}

# Groups of reduction functions that all share the same apply helper.
_function_fn_apply_map = {
    (
        tuple(NATIVE_REDUCE_FNS),
        tuple(TORCH_REDUCE_FNS),
        tuple(TENSOR_REDUCE_FNS),
    ): _apply_reduction,
}

# Register every reduction op: each entry is the shared helper partially
# applied with the op itself.
for _fn_groups, _apply in _function_fn_apply_map.items():
    for _group in _fn_groups:
        for _op in _group:
            _MASKEDTENSOR_FUNCTION_TABLE[_op] = partial(_apply, _op)
177
178
def register_function_func(ops):
    """Decorator that registers a ``__torch_function__`` handler for MaskedTensor.

    The decorated function is stored in ``_MASKEDTENSOR_FUNCTION_TABLE`` for
    every op in ``ops`` and later invoked as
    ``_MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)``.

    Registering a new function looks like:

    @register_function_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """

    def wrapper(func):
        _MASKEDTENSOR_FUNCTION_TABLE.update((op, partial(func, op)) for op in ops)

    return wrapper
196
197
@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_function_reductions(func, *args, **kwargs):
    # Route every reduction __torch_function__ through the shared helper.
    return _apply_reduction(func, *args, **kwargs)
201
202
@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
    """__torch_function__ entry for where(cond, x, y); defers to _MaskedWhere."""
    _check_args_kwargs_length(
        args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
    )
    cond, lhs, rhs = args
    return _MaskedWhere.apply(cond, lhs, rhs)
209
210
@register_function_func([torch.Tensor.contiguous])
def _function_contiguous(func, *args, **kwargs):
    # Delegate to the autograd-aware contiguous conversion.
    return _MaskedContiguous.apply(args[0])
214
215
@register_function_func([torch.Tensor.to_dense])
def _function_to_dense(func, *args, **kwargs):
    # Delegate to the autograd-aware dense-layout conversion.
    return _MaskedToDense.apply(args[0])
219
220
@register_function_func([torch.Tensor.to_sparse])
def _function_to_sparse(func, *args, **kwargs):
    # Delegate to the autograd-aware sparse-COO conversion.
    return _MaskedToSparse.apply(args[0])
224
225
@register_function_func([torch.Tensor.to_sparse_csr])
def _function_to_sparse_csr(func, *args, **kwargs):
    # Delegate to the autograd-aware sparse-CSR conversion.
    return _MaskedToSparseCsr.apply(args[0])
229
230
231_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {}
232
233
def register_dispatch_func(aten_ops):
    """Decorator that registers a ``__torch_dispatch__`` handler for MaskedTensor.

    The decorated function is stored in ``_MASKEDTENSOR_DISPATCH_TABLE`` for
    every aten op in ``aten_ops`` and later invoked as
    ``_MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)``.

    Registering a new function looks like:

    @register_dispatch_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """

    def wrapper(func):
        _MASKEDTENSOR_DISPATCH_TABLE.update(
            (aten_op, partial(func, aten_op)) for aten_op in aten_ops
        )

    return wrapper
251
252
@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_reduction(func, *args, **kwargs):
    # Route every reduction aten op through the shared helper.
    return _apply_reduction(func, *args, **kwargs)
256
257
@register_dispatch_func(PASSTHROUGH_FNS)
def _general_passthrough(func, *args, **kwargs):
    # Ops that apply uniformly to data and mask go through the pass-through helper.
    return _apply_pass_through_fn(func, *args, **kwargs)
261
262
@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
def _general_unary(func, *args, **kwargs):
    # Unary ops (incl. in-place variants) go through the native-unary helper.
    return _apply_native_unary(func, *args, **kwargs)
266
267
@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
def _general_binary(func, *args, **kwargs):
    # Binary ops (incl. in-place variants) go through the native-binary helper.
    return _apply_native_binary(func, *args, **kwargs)
271
272
@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
    # MaskedTensor reports no stride information.
    return None
276
277
@register_dispatch_func([torch.ops.aten.sym_stride])
def sym_stride(func, *args, **kwargs):
    # Symbolic counterpart of stride(); likewise reports no stride information.
    return None
281
282
@register_dispatch_func([torch.ops.prim.layout])
def layout(func, *args, **kwargs):
    # The MaskedTensor's layout is that of its underlying data tensor.
    return _get_data(args[0]).layout
286
287
@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
    """is_contiguous of the underlying data; undefined for sparse data."""
    inner = _get_data(args[0])
    if inner.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
    return func(inner, *args[1:], **kwargs)
294
295
@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
    """is_strides_like_format of the underlying data; undefined for sparse data."""
    inner = _get_data(args[0])
    if inner.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_strides_like_format"
        )
    return func(inner, *args[1:], **kwargs)
304
305
@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
    """is_non_overlapping_and_dense of the underlying data; undefined for sparse data."""
    inner = _get_data(args[0])
    if inner.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
        )
    return func(inner, *args[1:], **kwargs)
314
315
@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
    """Dispatch-level contiguous(); routed through the autograd Function."""
    mt = args[0]
    if _get_data(mt).is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have contiguous")
    return _MaskedContiguous.apply(mt)
321
322
@register_dispatch_func([torch.ops.aten.new_empty_strided])
def new_empty_strided(func, *args, **kwargs):
    """new_empty_strided, restricted to the input's own size and stride.

    The result wraps the new data tensor with the input's mask.
    """
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    requested_size, requested_stride = args[1], args[2]
    if tuple(requested_size) != tuple(data.size()):
        raise ValueError(
            f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
        )
    if tuple(requested_stride) != tuple(data.stride()):
        raise ValueError(
            f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
        )
    return MaskedTensor(func(data, requested_size, requested_stride, **kwargs), mask)
337
338
@register_dispatch_func([torch.ops.aten._local_scalar_dense])
def _local_scalar_dense(func, *args, **kwargs):
    # Extract the Python scalar from the underlying data tensor.
    # NOTE(review): `not _maybe_get_mask(...)` rejects both a missing mask and a
    # present-but-False (masked-out) scalar mask, while the message only mentions
    # the missing-mask case — confirm the masked-out case is intended to raise.
    if not _maybe_get_mask(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))
344
345
@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
def _apply_fn_on_data(func, *args, **kwargs):
    # Apply the op to the data only and carry the mask over unchanged.
    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
349
350
@register_dispatch_func([torch.ops.aten._to_copy])
def _to_copy(func, *args, **kwargs):
    """Copy/convert the data tensor; the mask is carried over untouched."""
    source = args[0]
    copied = func(_get_data(source), *args[1:], **kwargs)
    return MaskedTensor(copied, _maybe_get_mask(source))
355
356
@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
    """Softmax over the data that ignores masked-out elements."""
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
    )
    input_mt, dim = args[0], args[1]
    mask = _maybe_get_mask(input_mt)
    # _masked_softmax takes the inverted mask (True == ignore).
    # NOTE(review): the trailing `2` selects a mask-interpretation mode of the
    # kernel — confirm against the _masked_softmax signature.
    result = torch.ops.aten._masked_softmax(_get_data(input_mt), ~mask, dim, 2)
    return MaskedTensor(result, mask)
366
367
@register_dispatch_func([torch.ops.aten.ones_like])
def ones_like(func, *args, **kwargs):
    """ones_like on the data, preserving the input's mask."""
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
    source = args[0]
    filled = func(_get_data(source), **kwargs)
    return MaskedTensor(filled, _maybe_get_mask(source))
373
374
@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
    """Backward of masked softmax.

    ``grad`` and ``output`` must both be MaskedTensors with matching masks;
    raises ValueError otherwise.
    """
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
    grad, output, dim, input_dtype = args
    if is_masked_tensor(grad) and is_masked_tensor(output):
        if not _masks_match(grad, output):
            # Fix: the message was missing the f-prefix, so "{func}" was
            # emitted literally instead of being interpolated.
            raise ValueError(
                f"__torch_dispatch__, {func}: expected the masks of grad and output to match"
            )
        grad_data = _get_data(grad)
        new_grad_data = torch.ops.aten._masked_softmax_backward(
            grad_data,
            _get_data(output),
            ~_maybe_get_mask(grad),
            dim % grad_data.ndim,  # normalize negative dims
        )
        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
        return res
    else:
        raise ValueError(
            f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
        )
397
398
@register_dispatch_func([torch.ops.aten.copy_])
def copy_(func, *args, **kwargs):
    """In-place copy of args[1]'s data into args[0]; the masks must already match."""
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    dst, src = args[0], args[1]
    if not _masks_match(_maybe_get_mask(dst), _maybe_get_mask(src)):
        raise ValueError("args[0] mask and args[1] mask must match but do not")
    func(_get_data(dst), _get_data(src))
    return dst
406
407
@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
    """aten.where over MaskedTensors: selects data and mask elementwise.

    ``args`` are ``(condition, x, y)``; plain-tensor operands are promoted to
    fully-unmasked MaskedTensors first.
    """
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        # Fix: the message was missing the f-prefix, so "{func}" was emitted
        # literally instead of being interpolated.
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mx = args[1]
    my = args[2]
    if not is_masked_tensor(mx):
        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
    if not is_masked_tensor(my):
        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
    new_data = func(args[0], mx.get_data(), my.get_data())
    new_mask = func(args[0], mx.get_mask(), my.get_mask())
    return MaskedTensor(new_data, new_mask)
424
425
@register_dispatch_func([torch.ops.aten._to_sparse])
def _to_sparse(func, *args, **kwargs):
    """Convert a (masked) tensor to the sparse COO layout."""
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        # Fix: the message was missing the f-prefix, so "{func}" was emitted
        # literally instead of being interpolated.
        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
    if mt.is_sparse_coo():
        return mt
    # Fix: operate on the (possibly wrapped) `mt`, not the raw args[0] — for a
    # plain-tensor input args[0] has no mask, so _maybe_get_mask(args[0]) would
    # return None and func(None) would fail.
    new_mask = func(_maybe_get_mask(mt)).coalesce()
    new_data = _get_data(mt).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)
441
442
@register_dispatch_func([torch.ops.aten._to_sparse_csr])
def _to_sparse_csr(func, *args, **kwargs):
    """Convert a (masked) tensor to the sparse CSR layout."""
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        # Fix: the message was missing the f-prefix, so "{func}" was emitted
        # literally instead of being interpolated.
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    if mt.is_sparse_csr():
        return mt
    # Fix: operate on the (possibly wrapped) `mt`, not the raw args[0] — for a
    # plain-tensor input args[0] has no mask, so _maybe_get_mask(args[0]) would
    # return None and func(None) would fail.
    new_mask = func(_maybe_get_mask(mt))
    new_data = _get_data(mt).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)
458
459
@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
    """Convert a (masked) tensor to the strided (dense) layout."""
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        # Fix: the message was missing the f-prefix, so "{func}" was emitted
        # literally instead of being interpolated.
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    # Fix: operate on the (possibly wrapped) `mt`, not the raw args[0] — for a
    # plain-tensor input args[0] has no mask, so _maybe_get_mask(args[0]) would
    # return None and func(None) would fail.
    new_data = func(_get_data(mt))
    new_mask = func(_maybe_get_mask(mt))
    return MaskedTensor(new_data, new_mask)
473
474
@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
    """Indices of the sparse data, wrapped as a fully-unmasked MaskedTensor.

    Assumes the underlying data is sparse.
    """
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    indices = _get_data(args[0]).indices()
    return MaskedTensor(indices, torch.ones_like(indices).bool())
483
484
@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
    """Values of the sparse data, wrapped as a fully-unmasked MaskedTensor."""
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    values = _get_data(args[0]).values()
    return MaskedTensor(values, torch.ones_like(values).bool())
492
493
@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
    """Build a sparse COO MaskedTensor; the last two args are indices/values."""
    new_args = list(args)
    # Unwrap MaskedTensor indices/values down to their raw data tensors.
    for i in (-1, -2):
        if is_masked_tensor(new_args[i]):
            new_args[i] = new_args[i].get_data()

    new_data = func(*new_args, **kwargs)
    # The mask shares the sparse structure, with all-ones values.
    new_args[-1] = torch.ones_like(new_args[-1])
    new_mask = func(*new_args, **kwargs).bool()

    return MaskedTensor(new_data, new_mask)
507
508
@register_dispatch_func([torch.ops.aten.is_same_size])
def is_same_size(func, *args, **kwargs):
    """Compare the sizes of the two underlying data tensors."""
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    lhs, rhs = args
    return _get_data(lhs).is_same_size(_get_data(rhs))
513
514
@register_dispatch_func([torch.ops.aten._is_any_true])
def _is_any_true(func, *args, **kwargs):
    """True if any unmasked element of a dense boolean MaskedTensor is True."""
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    mt = args[0]
    data = _get_data(mt)
    mask = _maybe_get_mask(mt)
    if mask is None:
        raise ValueError(
            f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor"
        )
    if data.dtype != torch.bool:
        raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor")
    if data.is_sparse:
        raise ValueError(f"MaskedTensors with sparse data do not have {func}")

    # `data & mask` zeroes masked-out entries, so they never count towards "any".
    return MaskedTensor(func(data & mask), torch.tensor(True))
532