xref: /aosp_15_r20/external/pytorch/test/test_segment_reductions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: scatter & gather ops"]
2
3from itertools import product
4from functools import partial
5
6import numpy as np
7import torch
8from torch.testing._internal.common_device_type import (
9    instantiate_device_type_tests,
10    dtypes,
11)
12from torch.testing._internal.common_utils import (
13    TestCase,
14    run_tests,
15    gradcheck,
16    parametrize,
17    skipIfRocm,
18)
19
20
# Every reduction name that the tests below pass to torch._segment_reduce.
reductions = [
    "max",
    "mean",
    "min",
    "sum",
    "prod",
]
22
23
def get_default_value(initial_value, reduction):
    """Return the expected output of a segment reduction over an empty segment.

    Args:
        initial_value: the explicit ``initial`` passed to the reduction, or
            ``None`` when no initial value is supplied.
        reduction: one of ``"max"``, ``"mean"``, ``"min"``, ``"sum"``,
            ``"prod"``.

    Returns:
        ``initial_value`` when it is not ``None``. Otherwise the reduction's
        identity: ``-inf`` for max, ``inf`` for min, ``0.0`` for sum, ``1.0``
        for prod. Mean has no identity, so it yields ``nan``.

    Raises:
        ValueError: if ``initial_value`` is ``None`` and ``reduction`` is not
            one of the names above. (Previously this silently returned
            ``None``, which would surface later as a confusing tensor error.)
    """
    if initial_value is not None:
        return initial_value
    identities = {
        "max": -float("Inf"),
        "mean": float("nan"),
        "min": float("Inf"),
        "sum": 0.0,
        "prod": 1.0,
    }
    try:
        return identities[reduction]
    except KeyError:
        raise ValueError(f"unknown reduction: {reduction!r}") from None
37
38
class TestSegmentReductions(TestCase):
    """Tests for ``torch._segment_reduce`` over reductions, dtypes and axes.

    Each ``test_*`` method receives a ``device`` argument and the dtype(s)
    chosen by its ``@dtypes`` decorator.
    """

    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        """Check ``torch._segment_reduce`` against expected values.

        The reduction is run twice: once segmented by ``lengths`` and once by
        the equivalent ``offsets``. Every mode's forward result is compared to
        ``expected_arr``, with NaNs treated as equal.

        Args:
            reduction: name passed as ``reduce``.
            device: device the tensors are created on.
            dtype: dtype of the data, expected result and expected gradient.
            unsafe: forwarded as ``unsafe``.
            axis: axis along which segments are reduced.
            initial_value: forwarded as ``initial`` (may be ``None``).
            data_arr: nested list (or array-like) with the input values.
            lengths_arr: segment lengths.
            expected_arr: expected forward output.
            expected_grad_arr: expected ``data.grad`` after
                ``result.sum().backward()``. Only used if ``check_backward``.
            check_backward: also check gradients in every mode, and run
                ``gradcheck`` for dtypes other than half/bfloat16/float.
            lengths_dtype: dtype of the lengths/offsets tensors.
        """
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # offsets = cumsum([0, l0, l1, ...]) along the last dim, i.e. segment
        # boundaries equivalent to `lengths`.
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
        for mode in ("lengths", "offsets"):
            segment_reduce_kwargs = dict(axis=axis, unsafe=unsafe, initial=initial_value)
            segment_reduce_kwargs[mode] = lengths if mode == "lengths" else offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            if not check_backward:
                # `continue`, not `return`: the offsets mode must still be
                # checked even when gradients are not.
                continue

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            # Fresh leaf tensor so the next mode's backward does not
            # accumulate into this mode's gradient.
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types;
            # there is also a small numerical difference with fp32.
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like NaN input, so replace NaNs with 10.
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch._segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        """1-D data with a NaN and a trailing empty segment, along axis 0 and -1."""
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                # Gradients are only checked when an explicit initial value
                # is supplied.
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_zero_length(self, device, dtypes):
        """Empty data split into two empty segments.

        Each segment should reduce to the initial value, or to the
        reduction's default when no initial value is given.
        """
        val_dtype, length_type = dtypes
        lengths = [0, 0]
        # A plain empty list rather than `torch.ones(0)`: `_test_common` feeds
        # the data to `torch.tensor`, which warns when given a tensor.
        data = []
        # Nothing to differentiate: the gradient is empty.
        expected_grad = []
        axis = 0

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                if initial is not None:
                    if reduction == "min":
                        initial_value = 1000  # some high number
                    elif reduction == "prod":
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                default_value = get_default_value(initial_value, reduction)
                expected_result = [default_value, default_value]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )

    @skipIfRocm
    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        """2-D data (6 x 2) with NaNs, reduced along axis 0."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        # Previously omitted, so the int64 variant was never exercised.
                        length_type,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        """Cases with `src`/`index`/`indptr` data, reduced along the last dim of `indptr`."""
        val_dtype, length_dtype = dtypes
        # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                def fn(x, mode='lengths'):
                    # Supply initial values to prevent gradcheck from failing for 0 length
                    # segments, where nan/inf reduction identities produce nans when
                    # calculating the numerical jacobian.
                    initial = {'min': 1000, 'max': -1000}.get(reduce, 1)
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    # Explicit keywords: the previous `*{x, reduce}` unpacked a set,
                    # whose iteration order is hash-dependent and could swap the args.
                    return torch._segment_reduce(data=x, reduce=reduce, **segment_reduce_kwargs)

                for mode in ('lengths', 'offsets'):
                    self.assertTrue(
                        gradcheck(
                            partial(fn, mode=mode),
                            (data.clone().detach().requires_grad_(True),),
                        )
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        """Forward-only check of 3-D data (5 x 2 x 5) reduced along axis 0.

        Expected values are computed with the matching NumPy reductions.
        Empty leading and trailing segments are filled with `initial_value`.
        """
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        for reduction in reductions:
            initial_value = 0
            if reduction == "max":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.max(data[:2], axis=0).tolist(),
                    np.max(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "mean":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.mean(data[:2], axis=0).tolist(),
                    np.mean(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "min":
                initial_value = 1000  # some high number
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.min(data[:2], axis=0).tolist(),
                    np.min(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "sum":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.sum(data[:2], axis=0).tolist(),
                    np.sum(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            elif reduction == "prod":
                initial_value = 1
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.prod(data[:2], axis=0).tolist(),
                    np.prod(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                    # Previously omitted, so the int64 variant was never exercised.
                    length_type,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        """With ``unsafe=False``, lengths that don't cover the data must raise."""
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)

        # test for error on 1-D lengths: they sum to 5 but data has 6 elements
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

        # test for error on multi-D lengths: the second row sums to 5, not 6
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
570
571
572
573
# Instantiate device-specific variants of the generic TestSegmentReductions
# class into this module's globals(). The generic test methods take a
# `device` argument, so they are not meant to run directly.
instantiate_device_type_tests(TestSegmentReductions, globals())

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
578