# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# reference Python implementations for C ops
import torch
from functorch._C import dim as _C

from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map


DimList = _C.DimList
import operator
from functools import reduce


# ops that are pointwise over their tensor arguments (see op_properties);
# they can run directly on the positional tensors once levels are matched
pointwise = set(op_properties.pointwise)


def prod(x):
    return reduce(operator.mul, x, 1)


def _wrap_dim(d, N, keepdim):
    from . import Dim

    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d


def _dims(d, N, keepdim, single_dim):
    from . import Dim

    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)


def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    from . import DimensionMismatchError

    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        idx, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
            )
        new_size = lhs_size // rhs_so_far
        d.size = new_size
    elif len(not_bound) > 1:
        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
        )
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
            )


def _tensor_levels(inp):
    from . import _Tensor

    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True


def _match_levels(v, from_levels, to_levels):
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v


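# Illustrative sketch (editorial, not used by the module): _match_levels
# reorders the levels a tensor already has into a target ordering and inserts
# size-1 axes for levels it does not have. Assuming a hypothetical first-class
# dim `i` that `v` does not carry:
#
#   v = torch.rand(3, 4)                          # positional levels (-2, -1)
#   _match_levels(v, llist([-2, -1]), [i, -2, -1])  # -> shape (1, 3, 4)

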
# make a single dimension positional but do not permute it,
# used for multi-tensor operators where the dim being acted on
# should not physically move if possible
def _positional_no_permute(self, dim, expand_dim=False):
    from . import Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched


def seq(a, b):
    from . import Dim

    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b


class isin:
    def __contains__(self, item):
        for x in self:
            if seq(item, x):
                return True
        return False

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass


class ltuple(isin, tuple):
    pass


empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    from . import _Tensor, Tensor, TensorLike
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if (
            isinstance(lhs, _Tensor)
            and isinstance(rhs, _Tensor)
            and lhs.ndim == 0
            and rhs.ndim == 0
        ):
            return DelayedMulTensor(lhs, rhs)
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        arg_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if (
                    isinstance(f, _Tensor)
                    and not f._has_device
                    and device_holding_tensor is not None
                ):
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(
                    t, result_levels, device_holding_tensor is not None
                )
            return t

        return tree_map(wrap, result)
    else:

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t

        with _enable_layers(all_dims):
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            # print("END", orig)
            return tree_map(wrap, result)


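# Sketch of the pointwise path above (assumes the dims() constructor exported
# by this package; names i, j are hypothetical):
#
#   i, j = dims(2)
#   a = torch.rand(3)[i]    # first-class dim i (size 3)
#   b = torch.rand(4)[j]    # first-class dim j (size 4)
#   c = a + b               # add is pointwise: operands are matched to the
#                           # union of levels (i, j) and the op runs once

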
def positional(self, *dims):
    from . import Dim, DimensionBindError, Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            view.append(ptensor.size(d))
        else:
            # a tuple of dims is a "dim pack": flatten it here and merge the
            # packed dims into a single positional dimension via the reshape below
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    nflat = len(flat_dims)
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(
                f"tensor of dimensions {self.dims} does not contain dim {d}"
            ) from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims) :])
    return result


def _contains_dim(input):
    from . import Dim

    for i in input:
        if isinstance(i, Dim):
            return True


def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]


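# Usage sketch for positional/expand above, calling the module-level functions
# directly (i, j are hypothetical dims; these helpers are bound as methods on
# first-class tensors elsewhere in the package):
#
#   i, j = dims(2)
#   x = torch.rand(3, 4)[i, j]        # binds i=3, j=4; no positional dims remain
#   y = positional(x, i, j)           # ordinary tensor of shape (3, 4) again
#   z = expand(torch.rand(4)[j], i)   # broadcast in the already-bound dim i

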
_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(
                    orig(self._batchtensor, *args, **kwargs), self._has_device
                )
        keepdim = (
            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
        )
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            # ops that only accept a single dim argument need a plain int, not a tuple
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn


def _def(name, *args, **kwargs):
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))


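# Sketch of how _wrap/_def are intended to be used (the concrete registrations
# live elsewhere in the package; "sum" is an assumed example):
#
#   _def("sum")              # patch _Tensor.sum with a dim-aware wrapper of torch.Tensor.sum
#   i = dims()
#   x = torch.rand(3, 4)[i]  # i bound to 3; one positional dim of size 4 remains
#   x.sum(i)                 # reduces over the first-class dim i

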
no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    def __init__(self) -> None:
        self.dims = llist()
        self.count = []

    def record(self, d):
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)
        else:
            # a repeated use of the same dim (e.g. a diagonal index) must not be
            # treated as a simple single-use binding in __getitem__ below
            self.count[self.dims.index(d)] += 1

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]


def t__getitem__(self, input):
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound tensor list, and determine its size, bind dim list
    #   (remember that None does not count toward the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
    # this handles bool indexing, as well as some other simple cases.

    is_simple = (
        not isinstance(input, Dim)
        and not isinstance(input, (tuple, list))
        and
        # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
        not (isinstance(input, TensorLike) and input.ndim == 0)
    )

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    # can further optimize this case
    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    dims_indexed = 0
    expanding_object = None
    dimlists = []
    for i, s in enumerate(input):
        if s is ... or (isinstance(s, DimList) and not s.is_bound):
            if expanding_object is not None:
                msg = (
                    "at most one ... or unbound dimension list can exist in indexing list but"
                    f" found 2 at offsets {i} and {expanding_object}"
                )
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
            dimlists.append(i)
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
        )
    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object : expanding_object + 1] = [no_slice] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)
    # flatten the dimlists into the indexing
    for i in reversed(dimlists):
        input[i : i + 1] = input[i]

    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    dims_seen.record(d)
                _bind_dims_to_size(sz, idx, f"offset {i}")
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    for i in reversed(dim_packs):
        input[i : i + 1] = input[i]

    # currently:
    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
    # self may have first-class dims as well.

    # to index:
    # drop the first-class dims from self, they just become direct indices of their positions

    # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
    # these dimensions will appear and need to be bound at the first place a tensor occurs

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # indices apply to ptensor rather than self, which has first-class dimensions
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)


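# Indexing sketch (i, j are hypothetical dims): __getitem__ both binds unbound
# dims to the sizes they index and turns bound single-use dims into slices:
#
#   i, j = dims(2)
#   x = torch.rand(3, 4)
#   xi = x[i, j]     # binds i=3, j=4; result has no positional dims
#   row = x[i, :]    # i indexes dim 0; one positional dim of size 4 remains

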
# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
    if isinstance(dim, int):
        return torch.stack(tensors, dim, out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            pt = pt.move_dim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))


_orig_split = torch.Tensor.split


def split(self, split_size_or_sections, dim=0):
    from . import _Tensor, Dim

    if isinstance(split_size_or_sections, int) or any(
        isinstance(t, int) for t in split_size_or_sections
    ):
        if isinstance(dim, Dim):
            raise ValueError(
                "when dim is specified as a Dim object, split sizes must also be dimensions."
            )
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    total_bound_size = 0
    unbound = []
    sizes = []
    for i, d in enumerate(split_size_or_sections):
        if d.is_bound:
            sizes.append(d.size)
            total_bound_size += d.size
        else:
            sizes.append(0)
            unbound.append(i)

    if unbound:
        assert (
            total_bound_size <= size
        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        chunk_size = -(-remaining_size // len(unbound))  # ceiling division
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz
    else:
        assert (
            total_bound_size == size
        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
    return tuple(
        t.index(dim, d)
        for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
    )


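# Sketch of the intended semantics of split above, calling the module-level
# function directly (i, a, b are hypothetical dims made with this package's
# dims() constructor):
#
#   i = dims()
#   x = torch.rand(10)[i]          # i bound to 10
#   a, b = dims(2)
#   xa, xb = split(x, (a, b), i)   # a and b are sized to cover i: 5 and 5
#   # xa carries dim a, xb carries dim b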