1# Owner(s): ["module: scatter & gather ops"] 2 3from itertools import product 4from functools import partial 5 6import numpy as np 7import torch 8from torch.testing._internal.common_device_type import ( 9 instantiate_device_type_tests, 10 dtypes, 11) 12from torch.testing._internal.common_utils import ( 13 TestCase, 14 run_tests, 15 gradcheck, 16 parametrize, 17 skipIfRocm, 18) 19 20 21reductions = ["max", "mean", "min", "sum", "prod"] 22 23 24def get_default_value(initial_value, reduction): 25 if initial_value is not None: 26 return initial_value 27 if reduction == "max": 28 return -float("Inf") 29 elif reduction == "mean": 30 return float("nan") 31 elif reduction == "min": 32 return float("Inf") 33 elif reduction == "sum": 34 return 0.0 35 elif reduction == "prod": 36 return 1.0 37 38 39class TestSegmentReductions(TestCase): 40 def _test_common( 41 self, 42 reduction, 43 device, 44 dtype, 45 unsafe, 46 axis, 47 initial_value, 48 data_arr, 49 lengths_arr, 50 expected_arr, 51 expected_grad_arr, 52 check_backward, 53 lengths_dtype=torch.int, 54 ): 55 lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype) 56 # generate offsets from lengths 57 zeros_shape = list(lengths.shape) 58 zeros_shape[-1] = 1 59 offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1) 60 61 data = torch.tensor( 62 data_arr, 63 device=device, 64 dtype=dtype, 65 requires_grad=True, 66 ) 67 expected_result = torch.tensor(expected_arr, device=device, dtype=dtype) 68 expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype) 69 for mode in ['lengths', 'offsets']: 70 segment_reduce_kwargs = dict( 71 axis=axis, 72 unsafe=unsafe, 73 initial=initial_value) 74 if (mode == 'lengths'): 75 segment_reduce_kwargs['lengths'] = lengths 76 else: 77 segment_reduce_kwargs['offsets'] = offsets 78 actual_result = torch._segment_reduce( 79 data=data, 80 reduce=reduction, 81 **segment_reduce_kwargs 82 ) 83 self.assertEqual( 84 expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True 85 ) 86 87 if not check_backward: 88 return 89 90 # Test backward 91 actual_result.sum().backward() 92 self.assertEqual( 93 expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True 94 ) 95 data = data.clone().detach().requires_grad_(True) 96 97 # gradcheck does not work well with bfloat16 or fp16 cpu types 98 # also there is small numerical difference with fp32 99 if dtype not in [torch.half, torch.bfloat16, torch.float]: 100 # gradcheck does not like "nan" input, setting to random 10 101 d_non_nan = np.nan_to_num(data_arr, nan=10) 102 new_data = torch.tensor( 103 # [10 if v == float("nan") else v for v in data], 104 d_non_nan, 105 device=device, 106 dtype=dtype, 107 requires_grad=True, 108 ) 109 self.assertTrue( 110 gradcheck( 111 lambda x: torch._segment_reduce( 112 data=x, 113 reduce=reduction, 114 **segment_reduce_kwargs 115 ), 116 (new_data,), 117 ) 118 ) 119 120 @dtypes( 121 *product( 122 (torch.half, torch.bfloat16, torch.float, torch.double), 123 (torch.int, torch.int64), 124 ) 125 ) 126 def test_simple_1d(self, device, dtypes): 127 val_dtype, length_type = dtypes 128 lengths = [1, 2, 3, 0] 129 data = [1, float("nan"), 3, 4, 5, 5] 130 131 for reduction in reductions: 132 for initial in [0, None]: 133 check_backward = True if initial is not None else False 134 initial_value = initial 135 default_value = get_default_value(initial_value, reduction) 136 if reduction == "max": 137 expected_result = [1, float("nan"), 5, default_value] 138 expected_grad = [1, 1, 0, 0, 0.5, 0.5] 139 elif reduction == "mean": 140 expected_result = [1, float("nan"), 4.666, default_value] 141 expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333] 142 elif reduction == "min": 143 if initial is not None: 144 initial_value = 1000 # some high number 145 default_value = get_default_value(initial_value, reduction) 146 expected_result = [1, float("nan"), 4, default_value] 147 expected_grad = [1.0, 1.0, 0, 1, 0, 0] 148 elif reduction == "sum": 149 expected_result = [1, float("nan"), 14, default_value] 150 expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 151 elif reduction == "prod": 152 if initial is not None: 153 initial_value = 2 # 0 initial_value will zero out everything for prod 154 default_value = get_default_value(initial_value, reduction) 155 expected_result = [2, float("nan"), 200, default_value] 156 expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0] 157 else: 158 expected_result = [1, float("nan"), 100, default_value] 159 expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0] 160 for axis in [0, -1]: 161 for unsafe in [True, False]: 162 self._test_common( 163 reduction, 164 device, 165 val_dtype, 166 unsafe, 167 axis, 168 initial_value, 169 data, 170 lengths, 171 expected_result, 172 expected_grad, 173 check_backward, 174 length_type, 175 ) 176 177 @dtypes( 178 *product( 179 (torch.half, torch.bfloat16, torch.float, torch.double), 180 (torch.int, torch.int64), 181 ) 182 ) 183 def test_simple_zero_length(self, device, dtypes): 184 val_dtype, length_type = dtypes 185 lengths = [0, 0] 186 data = torch.ones(0) 187 188 for reduction in reductions: 189 for initial in [0, None]: 190 check_backward = True if initial is not None else False 191 initial_value = initial 192 default_value = get_default_value(initial_value, reduction) 193 if reduction == "max": 194 expected_result = [default_value, default_value] 195 expected_grad = [] 196 elif reduction == "mean": 197 expected_result = [default_value, default_value] 198 expected_grad = [] 199 elif reduction == "min": 200 if initial is not None: 201 initial_value = 1000 # some high number 202 default_value = get_default_value(initial_value, reduction) 203 expected_result = [default_value, default_value] 204 expected_grad = [] 205 elif reduction == "sum": 206 expected_result = [default_value, default_value] 207 expected_grad = [] 208 elif reduction == "prod": 209 if initial is not None: 210 initial_value = 2 # 0 initial_value will zero out everything for prod 211 default_value = get_default_value(initial_value, reduction) 212 expected_result = [default_value, default_value] 213 expected_grad = [] 214 else: 215 expected_result = [default_value, default_value] 216 expected_grad = [] 217 for axis in [0]: 218 for unsafe in [True, False]: 219 self._test_common( 220 reduction, 221 device, 222 val_dtype, 223 unsafe, 224 axis, 225 initial_value, 226 data, 227 lengths, 228 expected_result, 229 expected_grad, 230 check_backward, 231 length_type, 232 ) 233 234 @skipIfRocm 235 @dtypes( 236 *product( 237 (torch.half, torch.bfloat16, torch.float, torch.double), 238 (torch.int, torch.int64), 239 ) 240 ) 241 def test_multi_d_simple(self, device, dtypes): 242 val_dtype, length_type = dtypes 243 axis = 0 244 lengths = [1, 2, 3, 0] 245 data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]] 246 247 for reduction in reductions: 248 for initial in [0, None]: 249 check_backward = True if initial is not None else False 250 initial_value = initial 251 default_value = get_default_value(initial_value, reduction) 252 if reduction == "max": 253 expected_result = [ 254 [1, 1], 255 [float("nan"), float("nan")], 256 [4, 3], 257 [default_value, default_value], 258 ] 259 expected_grad = [ 260 [1, 1], 261 [1, 0], 262 [0, 1], 263 [1, 0], 264 [0, 0], 265 [0, 1], 266 ] 267 elif reduction == "mean": 268 expected_result = [ 269 [1, 1], 270 [float("nan"), float("nan")], 271 [3, 2], 272 [default_value, default_value], 273 ] 274 expected_grad = [ 275 [1.0, 1.0], 276 [0.5, 0.5], 277 [0.5, 0.5], 278 [0.333, 0.333], 279 [0.333, 0.333], 280 [0.333, 0.333], 281 ] 282 elif reduction == "min": 283 if initial is not None: 284 initial_value = 1000 # some high number 285 default_value = get_default_value(initial_value, reduction) 286 expected_result = [ 287 [1, 1], 288 [float("nan"), float("nan")], 289 [2, 1], 290 [default_value, default_value], 291 ] 292 expected_grad = [ 293 [1.0, 1.0], 294 [1, 0], 295 [0, 1], 296 [0, 1], 297 [0, 0], 298 [1, 0], 299 ] 300 elif reduction == "sum": 301 expected_result = [ 302 [1, 1], 303 [float("nan"), float("nan")], 304 [9, 6], 305 [default_value, default_value], 306 ] 307 expected_grad = [ 308 [1.0, 1.0], 309 [1.0, 1.0], 310 [1.0, 1.0], 311 [1.0, 1.0], 312 [1.0, 1.0], 313 [1.0, 1.0], 314 ] 315 elif reduction == "prod": 316 if initial is not None: 317 initial_value = 2 # 0 initial_value will zero out everything for prod 318 default_value = get_default_value(initial_value, reduction) 319 expected_result = [ 320 [2, 2], 321 [float("nan"), float("nan")], 322 [48, 12], 323 [default_value, default_value], 324 ] 325 expected_grad = [ 326 [2.0, 2.0], 327 [6.0, float("nan")], 328 [float("nan"), 2.0], 329 [12.0, 12.0], 330 [16.0, 6.0], 331 [24.0, 4.0], 332 ] 333 else: 334 expected_result = [ 335 [1, 1], 336 [float("nan"), float("nan")], 337 [24, 6], 338 [default_value, default_value], 339 ] 340 expected_grad = [ 341 [1.0, 1.0], 342 [3.0, float("nan")], 343 [float("nan"), 1.0], 344 [6.0, 6.0], 345 [8.0, 3.0], 346 [12.0, 2.0], 347 ] 348 for unsafe in [True, False]: 349 self._test_common( 350 reduction, 351 device, 352 val_dtype, 353 unsafe, 354 axis, 355 initial_value, 356 data, 357 lengths, 358 expected_result, 359 expected_grad, 360 check_backward, 361 ) 362 363 @dtypes( 364 *product( 365 (torch.half, torch.bfloat16, torch.float, torch.double), 366 (torch.int, torch.int64), 367 ) 368 ) 369 @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean']) 370 def test_pytorch_scatter_test_cases(self, device, dtypes, reduce): 371 val_dtype, length_dtype = dtypes 372 # zero-length segments are filled with reduction inits contrary to pytorch_scatter. 373 tests = [ 374 { 375 'src': [1, 2, 3, 4, 5, 6], 376 'index': [0, 0, 1, 1, 1, 3], 377 'indptr': [0, 2, 5, 5, 6], 378 'sum': [3, 12, 0, 6], 379 'prod': [2, 60, 1, 6], 380 'mean': [1.5, 4, float('nan'), 6], 381 'min': [1, 3, float('inf'), 6], 382 'max': [2, 5, -float('inf'), 6], 383 }, 384 { 385 'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], 386 'index': [0, 0, 1, 1, 1, 3], 387 'indptr': [0, 2, 5, 5, 6], 388 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], 389 'prod': [[3, 8], [315, 480], [1, 1], [11, 12]], 390 'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]], 391 'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]], 392 'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]], 393 }, 394 { 395 'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]], 396 'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], 397 'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], 398 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], 399 'prod': [[3, 315, 1, 11], [48, 80, 12, 1]], 400 'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]], 401 'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]], 402 'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]], 403 }, 404 { 405 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], 406 'index': [[0, 0, 1], [0, 2, 2]], 407 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]], 408 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 409 'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]], 410 'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]], 411 [[7, 9], [float('nan'), float('nan')], [11, 12]]], 412 'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]], 413 [[7, 9], [float('inf'), float('inf')], [10, 11]]], 414 'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]], 415 [[7, 9], [-float('inf'), -float('inf')], [12, 13]]], 416 }, 417 { 418 'src': [[1, 3], [2, 4]], 419 'index': [[0, 0], [0, 0]], 420 'indptr': [[0, 2], [0, 2]], 421 'sum': [[4], [6]], 422 'prod': [[3], [8]], 423 'mean': [[2], [3]], 424 'min': [[1], [2]], 425 'max': [[3], [4]], 426 }, 427 { 428 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], 429 'index': [[0, 0], [0, 0]], 430 'indptr': [[0, 2], [0, 2]], 431 'sum': [[[4, 4]], [[6, 6]]], 432 'prod': [[[3, 3]], [[8, 8]]], 433 'mean': [[[2, 2]], [[3, 3]]], 434 'min': [[[1, 1]], [[2, 2]]], 435 'max': [[[3, 3]], [[4, 4]]], 436 }, 437 ] 438 for test in tests: 439 data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True) 440 indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device) 441 dim = indptr.ndim - 1 442 # calculate lengths from indptr 443 lengths = torch.diff(indptr, dim=dim) 444 expected = torch.tensor(test[reduce], dtype=val_dtype, device=device) 445 446 actual_result = torch._segment_reduce( 447 data=data, 448 reduce=reduce, 449 lengths=lengths, 450 axis=dim, 451 unsafe=True, 452 ) 453 self.assertEqual(actual_result, expected) 454 455 # test offsets 456 actual_result = torch._segment_reduce( 457 data=data, 458 reduce=reduce, 459 offsets=indptr, 460 axis=dim, 461 unsafe=True, 462 ) 463 self.assertEqual(actual_result, expected) 464 465 if val_dtype == torch.float64: 466 def fn(x, mode='lengths'): 467 initial = 1 468 # supply initial values to prevent gradcheck from failing for 0 length segments 469 # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian 470 if reduce == 'min': 471 initial = 1000 472 elif reduce == 'max': 473 initial = -1000 474 segment_reduce_args = {x, reduce} 475 segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial) 476 if mode == 'lengths': 477 segment_reduce_kwargs[mode] = lengths 478 elif mode == 'offsets': 479 segment_reduce_kwargs[mode] = indptr 480 return torch._segment_reduce(*segment_reduce_args, **segment_reduce_kwargs) 481 self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True)))) 482 self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True)))) 483 484 485 @dtypes( 486 *product( 487 (torch.half, torch.bfloat16, torch.float, torch.double), 488 (torch.int, torch.int64), 489 ) 490 ) 491 def test_multi_d(self, device, dtypes): 492 val_dtype, length_type = dtypes 493 axis = 0 494 lengths = [0, 2, 3, 0] 495 data = np.arange(50).reshape(5, 2, 5).tolist() 496 expected_grad = [] 497 498 # TODO: calculate grad and check correctness 499 check_backward = False 500 501 for reduction in reductions: 502 initial_value = 0 503 if reduction == "max": 504 expected_result = [ 505 np.full((2, 5), initial_value).tolist(), 506 np.max(data[:2], axis=0).tolist(), 507 np.max(data[2:], axis=0).tolist(), 508 np.full((2, 5), initial_value).tolist(), 509 ] 510 elif reduction == "mean": 511 expected_result = [ 512 np.full((2, 5), initial_value).tolist(), 513 np.mean(data[:2], axis=0).tolist(), 514 np.mean(data[2:], axis=0).tolist(), 515 np.full((2, 5), initial_value).tolist(), 516 ] 517 elif reduction == "min": 518 initial_value = 1000 # some high number 519 expected_result = [ 520 np.full((2, 5), initial_value).tolist(), 521 np.min(data[:2], axis=0).tolist(), 522 np.min(data[2:], axis=0).tolist(), 523 np.full((2, 5), initial_value).tolist(), 524 ] 525 elif reduction == "sum": 526 expected_result = [ 527 np.full((2, 5), initial_value).tolist(), 528 np.sum(data[:2], axis=0).tolist(), 529 np.sum(data[2:], axis=0).tolist(), 530 np.full((2, 5), initial_value).tolist(), 531 ] 532 elif reduction == "prod": 533 initial_value = 1 534 expected_result = [ 535 np.full((2, 5), initial_value).tolist(), 536 np.prod(data[:2], axis=0).tolist(), 537 np.prod(data[2:], axis=0).tolist(), 538 np.full((2, 5), initial_value).tolist(), 539 ] 540 for unsafe in [True, False]: 541 self._test_common( 542 reduction, 543 device, 544 val_dtype, 545 unsafe, 546 axis, 547 initial_value, 548 data, 549 lengths, 550 expected_result, 551 expected_grad, 552 check_backward, 553 ) 554 555 @dtypes(torch.int, torch.int64) 556 def test_unsafe_flag(self, device, dtype): 557 length_type = dtype 558 lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type) 559 data = torch.arange(6, dtype=torch.float, device=device) 560 561 # test for error on 1-D lenghts 562 with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"): 563 torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False) 564 565 # test for error on multi-D lengths 566 nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device) 567 nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6) 568 with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"): 569 torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False) 570 571 572 573 574instantiate_device_type_tests(TestSegmentReductions, globals()) 575 576if __name__ == "__main__": 577 run_tests() 578