# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import torch

from .core import (
    _map_mt_args_kwargs,
    _masks_match,
    _tensors_match,
    _wrap_result,
    is_masked_tensor,
)


__all__ = []  # type: ignore[var-annotated]

# aten binary-op names that MaskedTensor dispatches natively.
BINARY_NAMES = [
    "add",
    "atan2",
    "arctan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "div",
    "divide",
    "floor_divide",
    "fmod",
    "logaddexp",
    "logaddexp2",
    "mul",
    "multiply",
    "nextafter",
    "remainder",
    "sub",
    "subtract",
    "true_divide",
    "eq",
    "ne",
    "le",
    "ge",
    "greater",
    "greater_equal",
    "gt",
    "less_equal",
    "lt",
    "less",
    "maximum",
    "minimum",
    "fmax",
    "fmin",
    "not_equal",
]

# Ops with no aten in-place variant (plus "equal", which is not in
# BINARY_NAMES to begin with, so subtracting it is a no-op).
_EXCLUDED_FROM_INPLACE = {
    "logaddexp",
    "logaddexp2",
    "equal",
    "fmin",
    "minimum",
    "maximum",
    "fmax",
}

INPLACE_BINARY_NAMES = [
    f"{name}_" for name in set(BINARY_NAMES) - _EXCLUDED_FROM_INPLACE
]


def _get_at_least_one_mask(a, b):
    """Return the mask of whichever of ``a``/``b`` is a MaskedTensor.

    Raises:
        TypeError: if neither argument is a MaskedTensor.
        ValueError: if the two masks do not match.
    """
    if not (is_masked_tensor(a) or is_masked_tensor(b)):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    return a.get_mask() if is_masked_tensor(a) else b.get_mask()


def _binary_helper(fn, args, kwargs, inplace):
    """Apply the aten binary op ``fn`` to the first two entries of ``args``.

    Unwraps MaskedTensor arguments into (data, mask) pairs, runs ``fn`` on
    the data (handling strided, sparse COO, and sparse CSR layouts), and
    either writes the result back into ``args[0]`` (``inplace=True``) or
    wraps it into a new MaskedTensor.
    """
    if len(kwargs) != 0:
        raise ValueError("len(kwargs) must equal 0")
    if any(torch.is_tensor(extra) for extra in args[2:]):
        raise TypeError(
            "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
        )

    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())

    lhs_layout = data_args[0].layout
    rhs = data_args[1]
    layouts_agree = (
        torch.is_tensor(rhs) or is_masked_tensor(rhs)
    ) and lhs_layout == rhs.layout

    if lhs_layout == torch.sparse_coo:
        if layouts_agree:
            # Operating directly on .values() is only sound when both
            # operands share identical sparsity structure.
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError(
                    "input1 and input2 must have the same size for binary functions."
                )
            data_args[1] = data_args[1].values()

        coo_indices = data_args[0].indices()
        coo_size = data_args[0].size()
        data_args[0] = data_args[0].values()
        result_data = torch.sparse_coo_tensor(coo_indices, fn(*data_args), coo_size)

    elif lhs_layout == torch.sparse_csr:
        if layouts_agree:
            rows_match = _tensors_match(
                data_args[0].crow_indices(), data_args[1].crow_indices()
            )
            cols_match = _tensors_match(
                data_args[0].col_indices(), data_args[1].col_indices()
            )
            if not (rows_match and cols_match):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )
            data_args[1] = data_args[1].values()

        csr_crow = data_args[0].crow_indices()
        csr_col = data_args[0].col_indices()
        data_args[0] = data_args[0].values()
        result_data = torch.sparse_csr_tensor(csr_crow, csr_col, fn(*data_args))

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]

    result_mask = _get_at_least_one_mask(*args[:2])
    # sparse tensors don't have strides so we can only expand if the layout is strided
    if lhs_layout == torch.strided:
        result_mask = result_mask.expand_as(result_data)
    return _wrap_result(result_data, result_mask)


def _torch_binary(fn_name):
    """Build an out-of-place MaskedTensor wrapper for ``aten.<fn_name>``."""
    aten_op = getattr(torch.ops.aten, fn_name)

    def binary_fn(*args, **kwargs):
        return _binary_helper(aten_op, args, kwargs, inplace=False)

    return binary_fn


def _torch_inplace_binary(fn_name):
    """Build an in-place MaskedTensor wrapper for ``aten.<fn_name>``."""
    aten_op = getattr(torch.ops.aten, fn_name)

    def binary_fn(*args, **kwargs):
        return _binary_helper(aten_op, args, kwargs, inplace=True)

    return binary_fn


# aten overload packet -> MaskedTensor implementation.
NATIVE_BINARY_MAP = {
    getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
}
NATIVE_INPLACE_BINARY_MAP = {
    getattr(torch.ops.aten, name): _torch_inplace_binary(name)
    for name in INPLACE_BINARY_NAMES
}

NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())


def _is_native_binary(fn):
    """Return True if ``fn`` is a binary op with a native MaskedTensor impl."""
    return fn in NATIVE_BINARY_MAP or fn in NATIVE_INPLACE_BINARY_MAP


def _apply_native_binary(fn, *args, **kwargs):
    """Dispatch ``fn`` to its MaskedTensor impl, or return NotImplemented."""
    handler = NATIVE_BINARY_MAP.get(fn) or NATIVE_INPLACE_BINARY_MAP.get(fn)
    if handler is None:
        return NotImplemented
    return handler(*args, **kwargs)