# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

from functools import partial
from typing import Any, Callable, Dict, TYPE_CHECKING

import torch

from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
from .core import (
    _get_data,
    _masks_match,
    _maybe_get_mask,
    is_masked_tensor,
    MaskedTensor,
)
from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
from .reductions import (
    _apply_reduction,
    NATIVE_REDUCE_FNS,
    TENSOR_REDUCE_FNS,
    TORCH_REDUCE_FNS,
)
from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS


if TYPE_CHECKING:
    from torch._ops import OpOverload


__all__ = []  # type: ignore[var-annotated]


def _check_args_kwargs_length(
    args, kwargs, error_prefix, len_args=None, len_kwargs=None
):
    if len_args is not None and len_args != len(args):
        raise ValueError(
            f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
        )
    if len_kwargs is not None and len_kwargs != len(kwargs):
        raise ValueError(
            f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
        )


class _MaskedContiguous(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")

        if input.is_contiguous():
            return input

        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.contiguous(), mask.contiguous())

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _MaskedToDense(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")

        if input.layout == torch.strided:
            return input

        ctx.layout = input.layout
        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.to_dense(), mask.to_dense())

    @staticmethod
    def backward(ctx, grad_output):
        layout = ctx.layout

        if layout == torch.sparse_coo:
            return grad_output.to_sparse_coo()
        elif layout == torch.sparse_csr:
            return grad_output.to_sparse_csr()
        elif layout == torch.strided:
            return grad_output.to_dense()
        raise ValueError(f"to_dense: Unsupported input layout: {layout}")


class _MaskedToSparse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")

        # Following the sparse tensor convention, to_sparse always converts to
        # the sparse_coo layout
        if input.layout == torch.sparse_coo:
            return input

        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_coo().coalesce()
        sparse_data = data.sparse_mask(sparse_mask)

        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()
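

# Illustrative sketch (comments only; nothing here runs at import time): the
# autograd Functions above make layout conversion differentiable. Assuming a
# dense input, a round trip might look like
#
#   mt = MaskedTensor(torch.randn(2, 2), torch.tensor([[True, False],
#                                                      [False, True]]))
#   sparse_mt = _MaskedToSparse.apply(mt)       # data/mask in sparse COO layout
#   dense_mt = _MaskedToDense.apply(sparse_mt)  # back to the strided layout
#
# Each backward converts the incoming gradient back to the input's layout, so
# the conversions compose cleanly under autograd.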


class _MaskedToSparseCsr(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")

        if input._masked_data.ndim != 2:
            raise ValueError(
                f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
            )

        if input.layout == torch.sparse_csr:
            return input

        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_csr()
        sparse_data = data.sparse_mask(sparse_mask)

        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()


class _MaskedWhere(torch.autograd.Function):
    @staticmethod
    def forward(ctx, cond, self, other):
        ctx.mark_non_differentiable(cond)
        ctx.save_for_backward(cond)
        return torch.ops.aten.where(cond, self, other)

    @staticmethod
    def backward(ctx, grad_output):
        (cond,) = ctx.saved_tensors

        def masked_out_like(mt):
            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())

        return (
            None,
            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
        )


_MASKEDTENSOR_FUNCTION_TABLE = {}

_function_fn_apply_map = {
    (
        tuple(NATIVE_REDUCE_FNS),
        tuple(TORCH_REDUCE_FNS),
        tuple(TENSOR_REDUCE_FNS),
    ): _apply_reduction,
}

for fn_map_list, apply_fn in _function_fn_apply_map.items():
    for fn_map in fn_map_list:
        for fn in fn_map:
            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)


def register_function_func(ops):
    """
    Used to register a new __torch_function__ handler for MaskedTensor.
    Registered handlers are invoked via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs).

    The code to register a new function looks like:

    @register_function_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """

    def wrapper(func):
        for op in ops:
            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)

    return wrapper


@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_function_reductions(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
    )
    return _MaskedWhere.apply(*args)


@register_function_func([torch.Tensor.contiguous])
def _function_contiguous(func, *args, **kwargs):
    return _MaskedContiguous.apply(args[0])


@register_function_func([torch.Tensor.to_dense])
def _function_to_dense(func, *args, **kwargs):
    return _MaskedToDense.apply(args[0])


@register_function_func([torch.Tensor.to_sparse])
def _function_to_sparse(func, *args, **kwargs):
    return _MaskedToSparse.apply(args[0])


@register_function_func([torch.Tensor.to_sparse_csr])
def _function_to_sparse_csr(func, *args, **kwargs):
    return _MaskedToSparseCsr.apply(args[0])


_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {}


def register_dispatch_func(aten_ops):
    """
    Used to register a new __torch_dispatch__ handler for MaskedTensor.
    Registered handlers are invoked via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs).

    The code to register a new function looks like:

    @register_dispatch_func(list_of_ops)
    def foo(func, *args, **kwargs):
        <implementation>
    """

    def wrapper(func):
        for aten_op in aten_ops:
            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)

    return wrapper
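

# Illustrative sketch (hypothetical handler, kept as a comment so nothing extra
# is registered at import): a dispatch rule for a unary op could be written as
#
#   @register_dispatch_func([torch.ops.aten.neg])
#   def _neg(func, *args, **kwargs):
#       return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
#
# after which MaskedTensor.__torch_dispatch__ resolves the op through
# _MASKEDTENSOR_DISPATCH_TABLE[torch.ops.aten.neg]. (aten.neg is already
# covered by the NATIVE_UNARY_FNS registration below; it appears here purely
# for illustration.)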


@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_reduction(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_dispatch_func(PASSTHROUGH_FNS)
def _general_passthrough(func, *args, **kwargs):
    return _apply_pass_through_fn(func, *args, **kwargs)


@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
def _general_unary(func, *args, **kwargs):
    return _apply_native_unary(func, *args, **kwargs)


@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
def _general_binary(func, *args, **kwargs):
    return _apply_native_binary(func, *args, **kwargs)


@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.aten.sym_stride])
def sym_stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.prim.layout])
def layout(func, *args, **kwargs):
    return _get_data(args[0]).layout


@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_strides_like_format"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
    if _get_data(args[0]).is_sparse:
        raise ValueError("MaskedTensors with sparse data do not have contiguous")
    return _MaskedContiguous.apply(args[0])


@register_dispatch_func([torch.ops.aten.new_empty_strided])
def new_empty_strided(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    if tuple(args[1]) != tuple(data.size()):
        raise ValueError(
            f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
        )
    if tuple(args[2]) != tuple(data.stride()):
        raise ValueError(
            f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
        )
    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)


@register_dispatch_func([torch.ops.aten._local_scalar_dense])
def _local_scalar_dense(func, *args, **kwargs):
    if not _maybe_get_mask(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))


@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
def _apply_fn_on_data(func, *args, **kwargs):
    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
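

# Note (illustrative): ``aten.detach`` and ``aten.clone`` above only apply the
# op to the underlying data and reuse the existing mask, e.g. ``mt.clone()``
# behaves like ``MaskedTensor(mt.get_data().clone(), mt.get_mask())``, so the
# mask semantics never need to be restated per handler.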


@register_dispatch_func([torch.ops.aten._to_copy])
def _to_copy(func, *args, **kwargs):
    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
    return MaskedTensor(new_data, _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
    )
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
    return MaskedTensor(result_data, mask)


@register_dispatch_func([torch.ops.aten.ones_like])
def ones_like(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
    result_data = func(_get_data(args[0]), **kwargs)
    return MaskedTensor(result_data, _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
    grad, output, dim, input_dtype = args
    if is_masked_tensor(grad) and is_masked_tensor(output):
        if not _masks_match(grad, output):
            raise ValueError(
                f"__torch_dispatch__, {func}: expected the masks of grad and output to match"
            )
        grad_data = _get_data(grad)
        new_grad_data = torch.ops.aten._masked_softmax_backward(
            grad_data,
            _get_data(output),
            ~_maybe_get_mask(grad),
            dim % grad_data.ndim,
        )
        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
        return res
    else:
        raise ValueError(
            f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
        )


@register_dispatch_func([torch.ops.aten.copy_])
def copy_(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
        raise ValueError("args[0] mask and args[1] mask must match but do not")
    func(_get_data(args[0]), _get_data(args[1]))
    return args[0]


@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mx = args[1]
    my = args[2]
    if not is_masked_tensor(mx):
        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
    if not is_masked_tensor(my):
        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
    new_data = func(args[0], mx.get_data(), my.get_data())
    new_mask = func(args[0], mx.get_mask(), my.get_mask())
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._to_sparse])
def _to_sparse(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
    if mt.is_sparse_coo():
        return mt
    new_mask = func(_maybe_get_mask(args[0])).coalesce()
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)
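

# Note on the two registration layers (illustrative): a user-level call like
# ``mt.to_sparse()`` is intercepted first at the __torch_function__ layer
# (``_function_to_sparse`` above, which routes through ``_MaskedToSparse``),
# while ``aten._to_sparse`` calls that reach __torch_dispatch__ land in the
# ``_to_sparse`` handler above. Both build the result the same way: sparsify
# the mask, then restrict the data with ``Tensor.sparse_mask``.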


@register_dispatch_func([torch.ops.aten._to_sparse_csr])
def _to_sparse_csr(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    if mt.is_sparse_csr():
        return mt
    new_mask = func(_maybe_get_mask(args[0]))
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    new_data = func(_get_data(args[0]))
    new_mask = func(_maybe_get_mask(args[0]))
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
    # Assumes data is sparse
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    data = _get_data(args[0]).indices()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    data = _get_data(args[0]).values()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
    new_args = list(args)
    if is_masked_tensor(args[-1]):
        new_args[-1] = args[-1].get_data()
    if is_masked_tensor(args[-2]):
        new_args[-2] = args[-2].get_data()

    new_data = func(*new_args, **kwargs)
    new_args[-1] = torch.ones_like(new_args[-1])
    new_mask = func(*new_args, **kwargs).bool()

    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.is_same_size])
def is_same_size(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    return _get_data(args[0]).is_same_size(_get_data(args[1]))


@register_dispatch_func([torch.ops.aten._is_any_true])
def _is_any_true(func, *args, **kwargs):
    _check_args_kwargs_length(
        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
    )
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    if mask is None:
        raise ValueError(
            f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor"
        )
    if data.dtype != torch.bool:
        raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor")
    if data.is_sparse:
        raise ValueError(f"MaskedTensors with sparse data do not have {func}")

    return MaskedTensor(func(data & mask), torch.tensor(True))
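

# Illustrative end-to-end sketch (comments only): with the tables above
# populated, MaskedTensor ops resolve through _MASKEDTENSOR_FUNCTION_TABLE and
# _MASKEDTENSOR_DISPATCH_TABLE, e.g.
#
#   data = torch.tensor([1.0, 2.0, 3.0])
#   mask = torch.tensor([True, False, True])
#   mt = MaskedTensor(data, mask)
#   mt.sum()        # reduction registration -> _apply_reduction
#   mt.to_sparse()  # function table -> _MaskedToSparse
#   mt.clone()      # dispatch table -> _apply_fn_on_data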