# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIf,
    skipMeta,
    TEST_WITH_ROCM,
)

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    _assertGradAndGradgradChecks,
    dtype2prec_DONTUSE,
    instantiate_parametrized_tests,
    IS_JETSON,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    skipIfTorchDynamo,
)


class TestEmbeddingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_embedding_max_norm_unsorted_repeating_indices(self):
        def create_embedding(device):
            # Seed RNG so we get the same Embedding each time
            torch.manual_seed(0)
            return torch.nn.Embedding(
                num_embeddings=20, embedding_dim=64, max_norm=1.0
            ).to(device)

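        # The indices below are unsorted and heavily repeated; the renorm
        # triggered by max_norm must still normalize each referenced row
        # exactly once, so the CUDA result has to match the CPU reference.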
        ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000)
        out_cpu = create_embedding("cpu")(ix)

        ix = ix.to("cuda")
        out = create_embedding("cuda")(ix)
        self.assertEqual(out.cpu(), out_cpu)

    def test_embedding_sparse_basic(self):
        embedding = nn.Embedding(10, 20, sparse=True)
        input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
        embedding(input).sum().backward()
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

    def test_embedding_sparse_empty_tensor(self):
        embedding = nn.Embedding(0, 0, sparse=True)
        input = torch.tensor([], dtype=torch.int64)
        embedding(input).sum().backward()
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

        embedding = nn.Embedding(10, 0, sparse=True)
        input = torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]])
        embedding(input).sum().backward()
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

    def test_move_sparse_half_embedding(self):
        embedding = nn.Embedding(10, 3, sparse=True)
        self.assertEqual(embedding.weight.device.type, "cpu")
        self.assertEqual(embedding.weight.dtype, torch.get_default_dtype())
        embedding.to(torch.float16)
        self.assertEqual(embedding.weight.dtype, torch.float16)
        self.assertEqual(embedding.embedding_dim, 3)
        self.assertEqual(embedding.num_embeddings, 10)

        if torch.cuda.is_available():
            embedding.to("cuda")
            self.assertEqual(embedding.weight.device.type, "cuda")
            embedding.to("cpu")
            self.assertEqual(embedding.weight.device.type, "cpu")

    def test_embedding_max_norm(self):
        embedding = nn.Embedding(22, 5, max_norm=1.0)
        input = torch.tensor([2, 8, 8, 6], dtype=torch.long)
        output = embedding(input)
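        # max_norm renormalizes the looked-up rows in place, so the two
        # lookups of index 8 must agree and every row must have L2 norm <= 1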
        self.assertEqual(output[1], output[2])
        self.assertTrue(output.data.norm(p=2, dim=1).le(1).all())

    @parametrize_test(
        "dtype",
        (
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float,
            torch.double,
        ),
    )
    def test_embedding_from_pretrained(self, dtype):
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
        embedding = nn.Embedding.from_pretrained(a)
        self.assertEqual(a, embedding.weight.data)

        input = torch.LongTensor([0, 1])
        output = embedding(input)
        self.assertEqual(a, output)

    def test_embedding_bag_from_pretrained(self):
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        embedding = nn.EmbeddingBag.from_pretrained(a)
        self.assertEqual(a, embedding.weight)

        input = torch.tensor([0, 1], dtype=torch.long)
        output = embedding(input, torch.arange(input.size(0)))
        self.assertEqual(a, output)

    def test_embedding_from_pretrained_padding_idx(self):
        padding_idx = 2
        padding_vec = torch.ones(3) * 7
        embeddings = torch.rand(4, 3, requires_grad=True)
        with torch.no_grad():
            embeddings[padding_idx] = padding_vec
        embedding_nn = nn.Embedding.from_pretrained(embeddings, padding_idx=padding_idx)
        self.assertEqual(embedding_nn.weight[padding_idx], padding_vec)

    def test_embedding_bag_from_pretrained_padding_idx(self):
        padding_idx = 2
        embeddings = torch.rand(4, 3, requires_grad=True)
        embedding_nn = nn.EmbeddingBag.from_pretrained(
            embeddings, padding_idx=padding_idx
        )
        self.assertEqual(embedding_nn.weight, embeddings)

    def test_embedding_from_pretrained_options(self):
        with set_default_dtype(torch.double):
            a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            opts = {
                "max_norm": 2.0,
                "norm_type": 0.5,
                "scale_grad_by_freq": False,
                "sparse": True,
            }
            embedding = nn.Embedding.from_pretrained(a, **opts)
            input = torch.LongTensor([0, 1])
            output = embedding(input)
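            # from_pretrained uses `a` itself as the weight (no copy), so the
            # in-place renorm triggered by max_norm also rewrites `a`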
            # test output and that weight matrix was renormalized
            self.assertEqual(a, output)
            self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
            self.assertTrue(
                output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all()
            )

    def test_embedding_functional(self):
        a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        embeddings = torch.rand(4, 3, requires_grad=True)

        embed_old = torch.nn.Embedding(4, 3)
        embed_old.weight.data = embeddings.data
        # Trivial under eager, but useful when running with
        # PYTORCH_TEST_WITH_DYNAMO=1, since it ensures that setattr works correctly.
        self.assertEqual(embed_old.weight.data, embeddings.data)
        res_old = embed_old(a)

        res_F = F.embedding(a, embeddings)
        self.assertEqual(res_old, res_F)

        embed_old = torch.nn.Embedding(4, 3)
        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
        res_old = embed_old(a)
        res_F = F.embedding(a, embeddings, padding_idx=2)

        self.assertEqual(res_old, res_F)

    # https://github.com/pytorch/pytorch/issues/130806
    @largeTensorTest("40GB", device="cuda")
    def test_large_tensors(self):
        input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
        w = torch.randn([16032, 16384], device="cuda")
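        # The output has 131072 * 16384 = 2**31 elements, just past the signed
        # 32-bit indexing limit that the issue linked above regressed on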
        out = torch.nn.functional.embedding(input, w)
        self.assertEqual(out.dim(), 2)
        self.assertEqual(out.numel(), 2147483648)

    def test_embedding_bag_functional(self):
        a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        embeddings = torch.rand(4, 3, requires_grad=True)

        embed_old = torch.nn.EmbeddingBag(4, 3)
        embed_old.weight = torch.nn.Parameter(embeddings)
        res_old = embed_old(a)

        res_F = F.embedding_bag(a, embeddings)
        self.assertEqual(res_old, res_F)

        embed_old = torch.nn.EmbeddingBag(4, 3)
        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
        res_old = embed_old(a)
        res_F = F.embedding_bag(a, embeddings, padding_idx=2)

        self.assertEqual(res_old, res_F)

    # Make sure that an error is thrown if padding_idx is out of bounds
    def test_embedding_bag_padding_idx_error(self):
        a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        num_embeddings = 4
        num_features = 3
        embeddings = torch.rand(num_embeddings, num_features, requires_grad=True)

        functional_err_msg = r"padding_idx must be within the number of embeddings"
        module_err_msg = r"padding_idx must be within num_embeddings"

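        # Valid values of padding_idx lie in [-num_embeddings, num_embeddings);
        # scan a window two indices past each bound to hit both error paths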
        for padding_idx in range(-(num_embeddings + 2), (num_embeddings + 2)):
            if (padding_idx < -num_embeddings) or (padding_idx >= num_embeddings):
                with self.assertRaisesRegex(RuntimeError, functional_err_msg):
                    F.embedding_bag(a, embeddings, padding_idx=padding_idx)
                with self.assertRaisesRegex(AssertionError, module_err_msg):
                    torch.nn.EmbeddingBag(
                        num_embeddings, num_features, padding_idx=padding_idx
                    )
            else:
                F.embedding_bag(a, embeddings, padding_idx=padding_idx)
                torch.nn.EmbeddingBag(
                    num_embeddings, num_features, padding_idx=padding_idx
                )

    def test_embeddingbag_from_pretrained(self):
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        embeddingbag = nn.EmbeddingBag.from_pretrained(a)
        self.assertEqual(a, embeddingbag.weight.data)

        input = torch.LongTensor([[0, 1]])
        output = embeddingbag(input)
        self.assertEqual(a.mean(0, keepdim=True), output)

    def test_embeddingbag_from_pretrained_options(self):
        with set_default_dtype(torch.double):
            a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            opts = {
                "max_norm": 2.0,
                "norm_type": 0.5,
                "scale_grad_by_freq": False,
                "mode": "max",
                "sparse": False,
            }
            embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts)

            input = torch.LongTensor([[0, 1]])
            output = embeddingbag(input)
            self.assertEqual(a.max(0, keepdim=True)[0], output)
            self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
            self.assertTrue(
                a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all()
            )

    def test_embeddingbag_include_last_offset(self):
        # Test case from https://github.com/pytorch/pytorch/issues/89677
        embeddingbag = nn.EmbeddingBag(100, 3, include_last_offset=True, padding_idx=61)
        input = torch.tensor([0, 1, 2, 3])
        out = embeddingbag(input, torch.tensor([0, 3, 3]))
        out2 = embeddingbag(input, torch.tensor([0, 3, 4]))

        weight = embeddingbag.weight
        row0 = weight[0:3].mean(0)
        row1 = weight[3]
        ref_out = torch.stack([row0, row1])

        self.assertEqual(ref_out, out)
        self.assertEqual(ref_out, out2)


class TestEmbeddingNNDeviceType(NNTestCase):
    def test_embedding_dense_grad(self, device):
        with set_default_dtype(torch.double):
            embd = nn.Embedding(20, 20).to(device)
            weight = embd.weight

            def fn_wrapper(device):
                def fn(weight):
                    inp = torch.tensor(
                        [[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long
                    ).to(device)
                    return torch.nn.functional.embedding(inp, weight)

                return fn

            fn = fn_wrapper(device)
            _assertGradAndGradgradChecks(self, fn, (weight,))

    def test_embedding_scalar_weight_error(self, device):
        indices = torch.rand(2, 2, device=device).long()
        weights = [
            torch.tensor(1.0, device=device),
            torch.tensor(1.0, device=device).reshape(1, 1, 1),
        ]

        for weight in weights:
            with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"):
                torch.nn.functional.embedding(indices, weight)

    @dtypesIfCUDA(torch.float16, torch.float64)
    @dtypes(torch.float64)
    def test_embedding_backward(self, device, dtype):
        embedding = nn.Embedding(10, 3, sparse=True)
        tensor = torch.tensor([[7, 1, 3]])
        ones = torch.tensor(1.0, dtype=dtype).expand(3, 3)
        tensorTwice = tensor.repeat(1, 2)
        onesTwice = torch.cat((ones, ones))

        embedding = embedding.to(dtype=dtype).to(device)
        tensor = tensor.to(device)
        ones = ones.to(device)
        tensorTwice = tensorTwice.to(device)
        onesTwice = onesTwice.to(device)

        embedding.zero_grad()
        embedding(tensor[0]).sum().backward()
        self.assertEqual(embedding.weight.grad._indices(), tensor)
        self.assertEqual(embedding.weight.grad._values(), ones)

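        # Sparse gradients accumulate uncoalesced: a second backward appends
        # the same indices and values rather than summing them into one entry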
        embedding.zero_grad()
        embedding(tensor[0]).sum().backward()
        embedding(tensor[0]).sum().backward()
        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
        self.assertEqual(embedding.weight.grad._values(), onesTwice)

        embedding.zero_grad()
        embedding(tensor[0]).sum().backward()
        tensor[0, 0] = 8
        embedding(tensor[0]).sum().backward()
        tensorTwice[0, 3] = 8
        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
        self.assertEqual(embedding.weight.grad._values(), onesTwice)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_max_norm_backward(self, device, dtype):
        # gradcheck can't be used here: the in-place renorm makes the
        # analytical gradients differ from the computed ones
        weight = torch.randn((4, 4), device=device, dtype=dtype) * 2
        weight.requires_grad_()
        inp_list = [0, 1, 2, 2]
        inp = torch.tensor(inp_list, device=device)
        out = nn.functional.embedding(inp, weight, max_norm=1.0).sum()
        out.backward()

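        # The gradient of the sum w.r.t. each weight row is the row's lookup
        # count (row 2 appears twice, row 3 never), broadcast across the
        # embedding dimension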
        expected_grad = (
            torch.tensor([[1.0, 1.0, 2.0, 0.0]], device=device, dtype=dtype)
            .transpose(0, 1)
            .expand(4, 4)
        )
        self.assertEqual(weight.grad, expected_grad)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_max_norm_fwd_AD(self, device, dtype):
        if torch.device(device).type == "xla":
            self.skipTest("forward AD doesn't work on xla")

        # gradcheck can't be used here: the in-place renorm makes the
        # analytical gradients differ from the computed ones
        weight = torch.randn((4, 4), device=device, dtype=dtype) * 2
        tangent = torch.ones((4, 4), device=device, dtype=dtype)
        inp = torch.tensor([[0, 1], [2, 2]], device=device)
        with torch.autograd.forward_ad.dual_level():
            dual_weight = torch.autograd.forward_ad.make_dual(weight, tangent)
            out = nn.functional.embedding(inp, dual_weight, max_norm=1.0)
            jvp = torch.autograd.forward_ad.unpack_dual(out).tangent

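        # A lookup is a gather, so its JVP gathers rows of the tangent; with
        # an all-ones tangent every entry of the output tangent should be 1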
        expected_grad = torch.ones((2, 2, 4), device=device, dtype=dtype)
        self.assertEqual(jvp, expected_grad)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_padding_idx(self, device, dtype):
        embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype)
        input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][0].sum(), 0)
        self.assertEqual(output[1][2].sum(), 0)

        embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True).to(device, dtype)
        input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][0].sum(), 0)
        self.assertEqual(output[1][2].sum(), 0)

        # negative indexing check for padding_idx
        # padding_idx=-2, num_embeddings=10 ==> index 8 padded
        embedding = nn.Embedding(10, 20, padding_idx=-2).to(device, dtype)
        input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][2].sum(), 0)
        self.assertEqual(output[1][1].sum(), 0)

        embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True).to(device, dtype)
        input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][2].sum(), 0)
        self.assertEqual(output[1][1].sum(), 0)

        # change padding vector
        padding_vector = torch.ones(20, dtype=dtype, device=device)
        embedding = nn.Embedding(10, 20, padding_idx=2, sparse=True).to(device, dtype)
        with torch.no_grad():
            embedding.weight[2] = padding_vector
        input = torch.tensor([0, 2], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[1], padding_vector)

        # out of bounds check for padding_idx
        self.assertRaises(
            AssertionError,
            nn.Embedding,
            num_embeddings=10,
            embedding_dim=20,
            padding_idx=25,
        )
        self.assertRaises(
            AssertionError,
            nn.Embedding,
            num_embeddings=10,
            embedding_dim=20,
            padding_idx=-25,
        )

        padding_idx = 0
        embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to(device, dtype)
        for n in (
            1,
            2,
            1000,
        ):  # Need large N to trigger all the methods we have implemented
            for other_indices in ([], [1, 3], [2]):
                indices = torch.tensor(
                    other_indices + [padding_idx] * n, dtype=torch.long
                ).to(device)
                pre = embedding.weight[padding_idx].clone()
                embedding(indices).sum().backward()
                after = (embedding.weight + embedding.weight.grad)[padding_idx]
                embedding.zero_grad()
                self.assertEqual(after, pre)

                # test double backward
                emb_sum = embedding(indices).sum()
                emb_grad = torch.autograd.grad(
                    outputs=emb_sum,
                    inputs=list(embedding.parameters()),
                    retain_graph=True,
                )
                scalar = emb_grad[0].sum() + emb_sum
                scalar.backward()
                after = (embedding.weight + embedding.weight.grad)[padding_idx]
                embedding.zero_grad()
                self.assertEqual(after, pre)

    # Check correctness of torch.nn.functional.embedding_bag forward and
    # backward functions with padding_idx, given a 1D input separated into bags
    # with an offset array. Compare against an equivalent 2D input that uses
    # padding indices to fill in the gaps indicated by the offset array

    @skipIfTorchDynamo("see https://github.com/pytorch/pytorch/pull/95621")
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
    def test_embedding_bag_1D_padding_idx(self, device, dtype):
        num_features = 3
        max_indices_per_bag = 10
        num_bags = 10
        num_words = 100

        def gen_1D_indices_offsets(include_last_offset, allpad):
            indices = []
            offsets = []
            cur_offset = 0

            # Make one bag full and one bag empty, for extra coverage
            empty_bag = random.randint(0, num_bags - 1)
            full_bag = empty_bag
            while full_bag == empty_bag:
                full_bag = random.randint(0, num_bags - 1)

            for bag in range(num_bags):
                offsets.append(cur_offset)
                if bag == full_bag:
                    bag_size = max_indices_per_bag
                elif bag == empty_bag:
                    bag_size = 0
                else:
                    bag_size = random.randint(1, max_indices_per_bag - 1)
                indices += [
                    1 if allpad else random.randint(0, num_words - 1)
                    for _ in range(bag_size)
                ]
                cur_offset += bag_size

            # embedding_bag requires the first entry of offsets to be 0
            assert offsets[0] == 0

            indices = torch.tensor(indices, device=device)

            if include_last_offset:
                offsets.append(indices.size(0))

            offsets = torch.tensor(offsets, device=device)

            return indices, offsets

        # Convert a 1-D indices-offsets representation into 2-D. Fill any empty
        # indices with padding_idx
        def gen_2D_indices_from_1D(
            indices_1D, offsets, include_last_offset, padding_idx
        ):
            assert offsets[0] == 0
            if include_last_offset:
                offsets = offsets[:-1]
            indices_2D = torch.empty(
                num_bags, max_indices_per_bag, device=device, dtype=torch.long
            )
            for bag in range(num_bags):
                # Determine the start and end position of the bag within indices_1D
                start = offsets[bag]
                end = len(indices_1D) if bag + 1 == num_bags else offsets[bag + 1]
                end = min(len(indices_1D), end)

                # Pull out the bag's indices from indices_1D, and fill any
                # remaining space with padding indices
                indices_in_bag = []
                for item_pos in range(0, max_indices_per_bag):
                    if (start + item_pos) < end:
                        indices_in_bag.append(indices_1D[start + item_pos])
                    else:
                        indices_in_bag.append(padding_idx)
                indices_2D[bag] = torch.tensor(indices_in_bag, device=device)

            return indices_2D

        test_cases = product(
            ["max", "mean", "sum"], [False, True], [False, True], [False, True]
        )

        for mode, sparse, include_last_offset, allpad in test_cases:
            # 'max' mode supports neither sparse gradients nor bfloat16
            if mode == "max":
                if sparse or (dtype == torch.bfloat16):
                    continue
            indices_1D, offsets = gen_1D_indices_offsets(include_last_offset, allpad)
            for padding_idx_1D in list(set(indices_1D.tolist())) + [None]:
                msg = (
                    f"mode: '{mode}', sparse: {sparse}, include_last_offset: {include_last_offset}, "
                    f"padding_idx_1D: {padding_idx_1D}"
                )

                # If the 1D input does not use a padding index, we still need
                # one for the 2D input, so add one dummy word to the weights
                # to act as the padded word
                padding_idx_2D = (
                    padding_idx_1D if padding_idx_1D is not None else num_words
                )
                num_words_with_padding = (
                    num_words if padding_idx_1D is not None else num_words + 1
                )

                indices_2D = gen_2D_indices_from_1D(
                    indices_1D, offsets, include_last_offset, padding_idx_2D
                )

                weights = torch.randn(
                    num_words_with_padding,
                    num_features,
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                )
                weights_check = weights.clone().detach().requires_grad_(True)

                bag = torch.nn.functional.embedding_bag(
                    indices_1D,
                    weights,
                    offsets,
                    padding_idx=padding_idx_1D,
                    mode=mode,
                    sparse=sparse,
                    include_last_offset=include_last_offset,
                )

                bag_check = torch.nn.functional.embedding_bag(
                    indices_2D,
                    weights_check,
                    padding_idx=padding_idx_2D,
                    mode=mode,
                    sparse=sparse,
                )
                self.assertEqual(bag, bag_check, msg=msg)

                bag.sum().backward()
                bag_check.sum().backward()

                # Sometimes, half dtype gradients mismatch by a greater amount
                # than other dtypes
                if dtype in [torch.half, torch.bfloat16]:
                    atol = 0.01
                    rtol = 0.01
                else:
                    atol = None
                    rtol = None
                self.assertEqual(
                    weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol
                )

    # Check correctness of torch.nn.functional.embedding_bag forward and
    # backward functions with padding_idx, given a 2D indices input. Compare
    # against torch.nn.functional.embedding followed by a reduction.
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
    def test_embedding_bag_2D_padding_idx(self, device, dtype):
        # Use a Python implementation of embedding_bag with padding_idx support
        # to check torch.nn.functional.embedding_bag correctness
        def embedding_bag_check(indices, weights, mode, sparse, padding_idx):
            assert padding_idx is not None
            embedding = torch.nn.functional.embedding(
                indices, weights, padding_idx=padding_idx, sparse=sparse
            )

            reduction_dim = indices.dim() - 1

            if mode == "sum" or mode == "mean":
                # We must avoid including elements at padding_idx in the
                # sum/mean, so multiply those elements by 0, and multiply
                # all other elements by 1
                per_sample_weights = indices.ne(padding_idx).to(dtype).unsqueeze(-1)
                res = embedding.mul(per_sample_weights).sum(dim=reduction_dim)

                if mode == "mean":
                    weights_sum = per_sample_weights.sum(dim=reduction_dim)
                    res = res.div(weights_sum)

            elif mode == "max":
                # We must avoid allowing elements at padding_idx to be chosen
                # as the max, so set those elements to negative infinity
                res = embedding.masked_fill(
                    indices.unsqueeze(-1) == padding_idx, -float("inf")
                ).amax(dim=reduction_dim)

            else:
                raise RuntimeError(f"mode '{mode}' is not available")

            # If a row is all padding, set its corresponding result row to 0.
            # This is needed because the above mean and max mode
            # implementations set these elements to nan and -inf, respectively
            if mode in ["mean", "max"]:
                res = res.masked_fill(
                    indices.eq(padding_idx).all(dim=-1).unsqueeze(-1), 0
                )

            return res

        num_features = 3
        num_words = 10
        indices_dim1 = 10

        for mode, sparse, allpad, indices_dim0 in product(
            ["max", "mean", "sum"], [False, True], [False, True], [1, 10]
        ):
            # 'max' mode supports neither sparse gradients nor bfloat16
            if mode == "max":
                if sparse or (dtype == torch.bfloat16):
                    continue

            if allpad:
                indices = torch.empty(
                    indices_dim0, indices_dim1, dtype=torch.long, device=device
                ).fill_(1)
            else:
                indices = torch.randint(
                    0, num_words, (indices_dim0, indices_dim1), device=device
                )

                if indices_dim0 > 1:
                    # Fill one row with a duplicate index so we can test with a
                    # fully padded row
                    duplicate_row = random.randint(0, indices_dim0 - 1)
                    indices[duplicate_row] = indices[duplicate_row][0]

            for padding_idx in list(set(indices.flatten(0, -1).tolist())):
                weights = torch.randn(
                    num_words,
                    num_features,
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                )
                weights_check = weights.clone().detach().requires_grad_(True)

                msg = (
                    f"mode: '{mode}', sparse: {sparse}, padding_idx: {padding_idx}, "
                    f"allpad: {allpad}, indices.size(): {indices.size()}"
                )

                # Check forward with a Python implementation of padding_idx embedding_bag
                bag_check = embedding_bag_check(
                    indices, weights_check, mode, sparse, padding_idx
                )
                bag = torch.nn.functional.embedding_bag(
                    indices, weights, padding_idx=padding_idx, mode=mode, sparse=sparse
                )

                self.assertEqual(bag, bag_check, msg=msg)

                bag_check.sum().backward()
                grad_check = weights_check.grad

                bag.sum().backward()
                grad = weights.grad

                # Sometimes, half dtype gradients mismatch by a greater amount
                # than other dtypes
                if dtype in [torch.half, torch.bfloat16]:
                    atol = 0.01
                    rtol = 0.01
                else:
                    atol = None
                    rtol = None
                self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)

    @onlyCUDA
    @dtypes(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    def test_embedding_max_norm_device(self, device, dtype):
        embedding = nn.Embedding(22, 5, max_norm=1.0).to(device, dtype=dtype)
        # nn.Embedding only takes LongTensor as input
        input = torch.tensor([2, 8, 8, 6], device=device, dtype=torch.long)
        output = embedding(input)
        self.assertEqual(output[1], output[2])
        self.assertTrue(output.data.norm(p=2, dim=1).le(1).all())

    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_empty_input(self, device, dtypes):
        m = 4
        n = 3
        x = torch.tensor([], device=device, dtype=dtypes[0])
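        # Bags that contain no indices must come back as all-zero rows, for
        # both dense and sparse weights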
        for sparse in [True, False]:
            Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse)
            Embed.to(device)

            output = Embed(
                input=x, offsets=torch.tensor([0], device=device, dtype=dtypes[1])
            )
            self.assertEqual(output, torch.zeros_like(output))

            output = Embed(
                input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtypes[1])
            )
            self.assertEqual(output, torch.zeros_like(output))

    @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.")
    @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long)))
    @parametrize_test("padding_idx", [None, 0])
    @parametrize_test("mode", ["sum", "mean", "max"])
    def test_embedding_bag_out_of_bounds_idx(self, device, dtypes, padding_idx, mode):
        w_dtype, idx_dtype = dtypes
        # negative out-of-bound
        idx1 = torch.tensor([[-1, 1]], device=device, dtype=idx_dtype)
        # positive out-of-bound
        idx2 = torch.tensor([[11, 8]], device=device, dtype=idx_dtype)
        weight = torch.randn(10, 2, device=device, dtype=w_dtype)
        if mode == "sum":
            # Only `sum` supports per_sample_weights
            per_sample_weights = (
                None,
                torch.randn_like(idx1, device=device, dtype=w_dtype),
            )
        else:
            per_sample_weights = (None,)

        for p_s_weights, idx in itertools.product(per_sample_weights, (idx1, idx2)):
            msg = "Expected idx >= 0 && idx < num_embeddings"
            with self.assertRaisesRegex(RuntimeError, msg):
                torch.nn.functional.embedding_bag(
                    idx,
                    weight,
                    per_sample_weights=p_s_weights,
                    padding_idx=padding_idx,
                    mode=mode,
                )

    def test_embedding_bag_dimension_errors(self, device):
        funcs = (
            lambda x, y, z: torch.nn.functional.embedding_bag(y, x, z),
            torch.embedding_bag,
            torch._embedding_bag,
            torch._embedding_bag_forward_only,
        )
        for i, f in enumerate(funcs):
            err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError

            weight = torch.full(
                (
                    2,
                    6,
                ),
                0,
                dtype=torch.float64,
                device=device,
            )
            indices = torch.full(
                (
                    2,
                    0,
                    0,
                    6,
                    6,
                ),
                2,
                dtype=torch.int64,
                device=device,
            )
            offsets = torch.full((2, 0, 0, 6, 6), 0, dtype=torch.int64, device=device)

            if i == 0:
                error_msg = "input has to be 1D or 2D Tensor"
            else:
                error_msg = "input has to be a 1D or 2D Tensor"
            torch._dynamo.disable(self.assertRaisesRegex)(
                err_type, error_msg, lambda: f(weight, indices, offsets)
            )

            weight = torch.full((2, 2), 0, dtype=torch.float64, device=device)
            indices = torch.full((2,), 1, dtype=torch.int64, device=device)

            torch._dynamo.disable(self.assertRaisesRegex)(
                err_type,
                "offsets has to be a 1D Tensor",
                lambda: f(weight, indices, offsets),
            )

            weight = torch.full((2, 2, 2), 0, dtype=torch.float64, device=device)
            indices = torch.full((2,), 2, dtype=torch.int64, device=device)
            offsets = torch.full((2,), 0, dtype=torch.int64, device=device)

            torch._dynamo.disable(self.assertRaisesRegex)(
                err_type,
                "weight has to be a 2D Tensor",
                lambda: f(weight, indices, offsets),
            )

    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes):
        # Failure 1: mismatched embeddings / per_sample_weights dtype
        es = nn.EmbeddingBag(5, 2, mode="sum").to(dtype=torch.float, device=device)
        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
        per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
        if device == "cpu":
            with self.assertRaisesRegex(RuntimeError, "have the same type as"):
                es(input, offsets, per_sample_weights)
        else:
            with self.assertRaisesRegex(RuntimeError, "expected scalar type"):
                es(input, offsets, per_sample_weights)

        # Failure 2.1: input/per_sample_weights have different sizes (1d input)
        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
        per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
        with self.assertRaisesRegex(ValueError, "same shape as the input"):
            es(input, offsets, per_sample_weights)

        # Failure 2.2: input/per_sample_weights have different sizes (2d input)
        input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
        offsets = None
        per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
        with self.assertRaisesRegex(ValueError, "same shape as the input"):
            es(input, offsets, per_sample_weights)

        # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean')
        for unsupported_mode in ("max", "mean"):
            es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
                dtype=torch.float, device=device
            )
            input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
            offsets = None
            per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
            with self.assertRaisesRegex(
                NotImplementedError, "only supported for mode='sum'"
            ):
                es(input, offsets, per_sample_weights)

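    # Reference implementation used by the EmbeddingBag tests below: gather
    # the rows of `weight` selected by `input`, scale them by the per-sample
    # weights, then reduce each bag [offsets[i], offsets[i + 1]) with
    # sum/mean/max; empty bags map to all-zero rows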
    def _embedding_bag_reference_impl(
        self,
        input,
        weight,
        offsets=None,
        mode="sum",
        per_sample_weights=None,
        include_last_offset=False,
    ):
        assert mode == "sum" or per_sample_weights is None
        assert offsets is not None
        if per_sample_weights is None:
            per_sample_weights = torch.ones(input.size()).to(
                dtype=weight.dtype, device=weight.device
            )
        assert input.numel() == per_sample_weights.numel()

        bags = []
        long_input = input.to(torch.long)
        embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(
            1
        )
        if include_last_offset:
            for index in range(len(offsets) - 1):
                offset = offsets[index]
                next_offset = offsets[index + 1]
                length = next_offset - offset
                if length == 0:
                    bags.append(
                        torch.tensor([0] * weight.size(1)).to(
                            dtype=embeddings.dtype, device=embeddings.device
                        )
                    )
                else:
                    if mode == "sum":
                        bags.append(embeddings.narrow(0, offset, length).sum(0))
                    elif mode == "mean":
                        bags.append(
                            embeddings.narrow(0, offset, length).sum(0).div(length)
                        )
                    else:
                        assert mode == "max"
                        bags.append(embeddings.narrow(0, offset, length).max(0)[0])
        else:
            for index, offset in enumerate(offsets):
                if index + 1 < len(offsets):
                    next_offset = offsets[index + 1]
                else:
                    next_offset = len(long_input)
                length = next_offset - offset
                if length == 0:
                    bags.append(
                        torch.tensor([0] * weight.size(1)).to(
                            dtype=embeddings.dtype, device=embeddings.device
                        )
                    )
                else:
                    if mode == "sum":
                        bags.append(embeddings.narrow(0, offset, length).sum(0))
                    elif mode == "mean":
                        bags.append(
                            embeddings.narrow(0, offset, length).sum(0).div(length)
                        )
                    else:
                        assert mode == "max"
                        bags.append(embeddings.narrow(0, offset, length).max(0)[0])
        return torch.stack(bags)

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.half, torch.bfloat16, torch.float, torch.double),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
        # Test empty input and per sample weight, and backward pass. There was a CUDA
        # invalid configuration bug (more context in #46572)
        def test_per_sample_weights(mode, trainable_scale):
            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
            es.weight.data.copy_(
                torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2])
            )
            input = torch.tensor([], device=device, dtype=dtypes[0])
            offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[1])
            per_sample_weights = torch.randn_like(
                input, dtype=dtypes[2]
            ).requires_grad_(trainable_scale)
            ref_per_sample_weights = per_sample_weights.detach().requires_grad_(
                trainable_scale
            )
            reference_weights = es.weight.detach().requires_grad_()

            expected = self._embedding_bag_reference_impl(
                input, reference_weights, offsets, mode, ref_per_sample_weights
            )
            result = es(input, offsets, per_sample_weights)
            self.assertEqual(
                result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0
            )

            grad = torch.randn_like(expected)
            result.backward(grad)
            # the reference impl has no grad_fn for empty input, but the
            # gradient should simply be a zero tensor
            ref_weights_grad = torch.zeros_like(es.weight)
            self.assertEqual(
                es.weight.grad,
                ref_weights_grad,
                atol=dtype2prec_DONTUSE[dtypes[2]],
                rtol=0,
            )
            if trainable_scale:
                ref_per_sample_weights_grad = torch.empty_like(per_sample_weights)
                self.assertEqual(
                    per_sample_weights.grad,
                    ref_per_sample_weights_grad,
                    atol=dtype2prec_DONTUSE[dtypes[2]],
                    rtol=0,
                )

        modes = ("sum",)
        trainable_scale = (True, False)
        for mode, trainable in itertools.product(modes, trainable_scale):
            test_per_sample_weights(mode, trainable)

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
        def test_per_sample_weights(mode, trainable_scale):
            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
            es.weight.data.copy_(
                torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2])
            )
            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1])
            per_sample_weights = torch.randn_like(
                input, dtype=dtypes[2]
            ).requires_grad_(trainable_scale)
            ref_per_sample_weights = per_sample_weights.detach().requires_grad_(
                trainable_scale
            )
            reference_weights = es.weight.detach().requires_grad_()

            expected = self._embedding_bag_reference_impl(
                input, reference_weights, offsets, mode, ref_per_sample_weights
            )
            result = es(input, offsets, per_sample_weights)
            self.assertEqual(
                result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0
            )

            grad = torch.randn_like(expected).to(dtype=dtypes[2], device=device)
            result.backward(grad)
            expected.backward(grad)
            self.assertEqual(
                es.weight.grad,
                reference_weights.grad,
                atol=dtype2prec_DONTUSE[dtypes[2]],
                rtol=0,
            )
            if trainable_scale:
                self.assertEqual(
                    per_sample_weights.grad,
                    ref_per_sample_weights.grad,
                    atol=dtype2prec_DONTUSE[dtypes[2]],
                    rtol=0,
                )

        modes = ("sum",)
        trainable_scale = (True, False)
        for mode, trainable in itertools.product(modes, trainable_scale):
            test_per_sample_weights(mode, trainable)

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
        def test_per_sample_weights_new_offsets(
            mode, trainable_scale, include_last_offset, has_weight=True
        ):
            es = nn.EmbeddingBag(
                5, 2, mode=mode, include_last_offset=include_last_offset
            ).to(dtype=dtypes[2], device=device)
            es.weight.data.copy_(
                torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2])
            )
            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1])

            if include_last_offset:
                offsets = torch.cat(
                    (
                        offsets,
                        torch.tensor([input.size(0)], device=device, dtype=dtypes[1]),
                    ),
                    0,
                )

            if has_weight:
                per_sample_weights = torch.randn_like(
                    input, device=device, dtype=dtypes[2]
                ).requires_grad_(trainable_scale)
                ref_per_sample_weights = per_sample_weights.detach().requires_grad_(
                    trainable_scale
                )
            else:
                per_sample_weights = None
                ref_per_sample_weights = None

            reference_weights = es.weight.detach().requires_grad_()

            expected = self._embedding_bag_reference_impl(
                input,
                reference_weights,
                offsets,
                mode,
                ref_per_sample_weights,
                include_last_offset,
            )
            result = es(input, offsets, per_sample_weights)
            self.assertEqual(
                result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0
            )

            grad = torch.randn_like(expected)
            result.backward(grad)
            expected.backward(grad)
            self.assertEqual(
                es.weight.grad,
                reference_weights.grad,
                atol=dtype2prec_DONTUSE[dtypes[2]],
                rtol=0,
            )
            if has_weight and trainable_scale:
                self.assertEqual(
                    per_sample_weights.grad,
                    ref_per_sample_weights.grad,
                    atol=dtype2prec_DONTUSE[dtypes[2]],
                    rtol=0,
                )

        trainable_scale = (True, False)
        include_last_offset_list = (True, False)
        modes = (("sum", False), ("sum", True), ("max", False), ("mean", False))
        for (mode, has_weight), trainable, include_last_offset in itertools.product(
            modes, trainable_scale, include_last_offset_list
        ):
            test_per_sample_weights_new_offsets(
                mode, trainable, include_last_offset, has_weight
            )

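    # Cross-check EmbeddingBag against plain Embedding: with offsets at
    # multiples of L, EmbeddingBag over the flattened input must match
    # Embedding on the (B, L) input followed by a reduction over dim 1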
1204    def _test_EmbeddingBag_vs_Embedding(
1205        self,
1206        N,
1207        D,
1208        B,
1209        L,
1210        max_norm=None,
1211        mode="mean",
1212        device="cpu",
1213        wdtype=torch.float,
1214        dtype=torch.long,
1215        test_per_sample_weights=False,
1216        trainable_per_sample_weights=False,
1217        sparse=False,
1218        test_backward=True,
1219        backward_prec=None,
1220    ):
1221        es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(
1222            device, wdtype
1223        )
1224        e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype)
1225        e.weight.data.copy_(es.weight)
1226        input = torch.randint(N, (B, L), device=device, dtype=dtype)
1227        offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L)
1228        grad_output = torch.rand(B, D, device=device, dtype=wdtype)
1229
1230        if test_per_sample_weights:
1231            # To prevent large gradients, weights should sum to 1 for each bag
1232            per_sample_weights = torch.randn(B, L, device=device, dtype=wdtype).softmax(
1233                dim=-1
1234            )
1235            per_sample_weights_reference = per_sample_weights.clone().requires_grad_(
1236                trainable_per_sample_weights
1237            )
1238            per_sample_weights.requires_grad_(trainable_per_sample_weights)
1239            output = es(input.view(-1), offsets, per_sample_weights.view(-1))
1240        else:
1241            output = es(input.view(-1), offsets)
1242            per_sample_weights = None
1243            per_sample_weights_reference = None
1244
1245        if mode == "sum":
1246            if test_per_sample_weights:
1247                ref_output = (
1248                    e(input) * per_sample_weights_reference.unsqueeze(-1)
1249                ).sum(1)
1250            else:
1251                ref_output = e(input).sum(1)
1252        elif mode == "mean":
1253            assert not test_per_sample_weights
1254            ref_output = e(input).mean(1)
1255        elif mode == "max":
1256            assert not test_per_sample_weights
1257            ref_output = e(input).max(1)[0]
1258
1259        self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0)
1260
1261        if not test_backward:
1262            return
1263
1264        output.backward(grad_output)
1265        ref_output.backward(grad_output)
1266        es_weight_grad = es.weight.grad
1267        if sparse:
1268            es_weight_grad = es.weight.grad.to_dense()
1269
1270        # We have more floating point error here because we are dealing with larger numbers
1271        if backward_prec is None:
1272            needed_prec = dtype2prec_DONTUSE[wdtype] * 5
1273            rtol = 0.02 if wdtype == torch.half else 0
1274        else:
1275            needed_prec = backward_prec
1276            rtol = 0
1277
1278        self.assertEqual(es_weight_grad, e.weight.grad, atol=needed_prec, rtol=rtol)
1279
1280        if test_per_sample_weights and trainable_per_sample_weights:
1281            self.assertEqual(
1282                per_sample_weights.grad,
1283                per_sample_weights_reference.grad,
1284                atol=dtype2prec_DONTUSE[wdtype],
1285                rtol=0,
1286            )
1287
1288    @dtypesIfCUDA(
1289        *itertools.product(
1290            (torch.int, torch.long), (torch.half, torch.float, torch.double)
1291        )
1292    )
1293    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
1294    def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
        def run_tests(mode, sparse, trainable_per_sample_weights):
            kwargs = dict(
                test_per_sample_weights=True,
                device=device,
                mode=mode,
                wdtype=dtypes[1],
                dtype=dtypes[0],
                sparse=sparse,
                trainable_per_sample_weights=trainable_per_sample_weights,
            )

            # Simple case
            self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs)

            # B * L > 1000
            self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs)

            # Large num_embeddings
            self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs)

            # Large embedding_dim
            self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)

        modes = ("sum",)
        sparsity = (True, False)
        trainable_scale = (True, False)
        for mode, sparse, trainable_per_sample_weights in itertools.product(
            modes, sparsity, trainable_scale
        ):
            run_tests(mode, sparse, trainable_per_sample_weights)

        # Test CUDA Dense on half precision
        if device == "cuda":
            modes = ("sum",)
            sparsity = (False,)
            trainable_scale = (True, False)
            for mode, sparse, trainable_per_sample_weights in itertools.product(
                modes, sparsity, trainable_scale
            ):
                run_tests(mode, sparse, trainable_per_sample_weights)

    def _test_EmbeddingBag(
        self,
        device,
        mode,
        sparse,
        wdtype=torch.double,
        dtype=torch.long,
        odtype=torch.long,
        test_backward=True,
    ):
        # check a known test example
        es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype)
        es.weight.data.copy_(
            torch.arange(1, 11, device=device).view_as(es.weight).to(wdtype)
        )
        input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype)
        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=odtype)
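        # offsets partition `input` into 5 bags:
        #   bag0 = input[0:0] = []         (empty)
        #   bag1 = input[0:3] = [3, 1, 1]
        #   bag2 = input[3:3] = []         (empty)
        #   bag3 = input[3:6] = [1, 4, 0]
        #   bag4 = input[6:6] = []         (empty)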

        grad_output = torch.tensor([1, 2, 3, 4], device=device, dtype=wdtype).view(2, 2)
        grad_output_with_empty = torch.tensor(
            [99, 99, 1, 2, 99, 99, 3, 4, 99, 99], device=device, dtype=wdtype
        ).view(5, 2)
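        # The 99s sit in the empty-bag slots (rows 0, 2, and 4); their values
        # must not leak into the weight gradient.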
        if mode == "sum" or mode == "mean":
            denominator = 1 if mode == "sum" else 3
            expected_output = (
                torch.tensor([[13, 16], [13, 16]], device=device, dtype=wdtype)
                / denominator
            )
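            # weight rows are [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], so
            # bag1 = rows 3 + 1 + 1 = [7, 8] + [3, 4] + [3, 4] = [13, 16] and
            # bag3 = rows 1 + 4 + 0 = [3, 4] + [9, 10] + [1, 2] = [13, 16].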

            expected_output_with_empty = (
                torch.tensor(
                    [[0, 0], [13, 16], [0, 0], [13, 16], [0, 0]],
                    device=device,
                    dtype=wdtype,
                )
                / denominator
            )

            expected_grad_weight = (
                torch.tensor(
                    [[3, 4], [5, 8], [0, 0], [1, 2], [3, 4]],
                    device=device,
                    dtype=wdtype,
                )
                / denominator
            )
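            # backward routes grad row [1, 2] (bag1) to rows 3, 1, 1 and grad
            # row [3, 4] (bag3) to rows 1, 4, 0, so row 1 accumulates
            # 2 * [1, 2] + [3, 4] = [5, 8]; the 99s contribute nothing.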
        elif mode == "max":
            expected_output = torch.tensor(
                [[7, 8], [9, 10]], device=device, dtype=wdtype
            )

            expected_output_with_empty = torch.tensor(
                [[0, 0], [7, 8], [0, 0], [9, 10], [0, 0]], device=device, dtype=wdtype
            )

            expected_grad_weight = torch.tensor(
                [[0, 0], [0, 0], [0, 0], [1, 2], [3, 4]], device=device, dtype=wdtype
            )
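            # max selects row 3 ([7, 8]) for bag1 and row 4 ([9, 10]) for
            # bag3, so only those rows receive gradient.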
        output = es(input, offsets)
        output.backward(grad_output_with_empty)

        es_weight_grad = es.weight.grad
        if sparse:
            es_weight_grad = es.weight.grad.to_dense()
        self.assertEqual(output, expected_output_with_empty)
        self.assertEqual(
            es_weight_grad,
            expected_grad_weight,
            atol=dtype2prec_DONTUSE[wdtype],
            rtol=0,
        )

        # check same example except as 2D (2 x 3)
        input = input.view(2, -1)
        es.zero_grad()
        output = es(input)
        output.backward(grad_output)

        es_weight_grad = es.weight.grad
        if sparse:
            es_weight_grad = es.weight.grad.to_dense()
        self.assertEqual(output, expected_output)
        self.assertEqual(
            es_weight_grad,
            expected_grad_weight,
            atol=dtype2prec_DONTUSE[wdtype],
            rtol=0,
        )

        # test all empty bags
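        # with every bag empty, the forward output is all zeros and backward
        # touches no weight row, so the gradient must be exactly zero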
        es.zero_grad()
        inputs = torch.tensor([], dtype=dtype, device=device)
        offsets = torch.tensor([0, 0, 0, 0], dtype=odtype, device=device)
        es(inputs, offsets).sum().backward()
        dense_grad = es.weight.grad
        if dense_grad.is_sparse:
            dense_grad = dense_grad.to_dense()
        self.assertEqual(dense_grad, torch.zeros_like(es.weight))

        # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
        N, D, B, L = (
            random.randint(1, 100),
            random.randint(1, 100),
            random.randint(1, 50),
            random.randint(1, 50),
        )
        kwargs = dict(
            mode=mode,
            sparse=sparse,
            device=device,
            wdtype=wdtype,
            dtype=dtype,
            test_backward=test_backward,
        )
        self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
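        # sweep every (N, D, B, L) combination in {1, 2}**4, with and without
        # max_norm, to exercise the degenerate small shapes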
        for max_norm in (None, 3):
            for p in itertools.product([1, 2], repeat=4):
                self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs)

        # check that giving illegal input combos raises error
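        # (2D input must not be combined with offsets, 1D input requires
        # offsets, and offsets must start at 0 and stay within bounds; the
        # bounds checks fire eagerly on CPU)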
        es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
        input = torch.ones(3, 4, dtype=dtype)
        offset = torch.arange(0, 3, dtype=odtype)
        torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input, offset))
        torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input.view(-1)))
        offset[0] = 1
        if self.device_type == "cpu":
            torch._dynamo.disable(self.assertRaises)(
                RuntimeError, lambda: es(input.view(-1), offset)
            )
            offset[0] = 0
            offset[-1] = 100
            torch._dynamo.disable(self.assertRaises)(
                RuntimeError, lambda: es(input.view(-1), offset)
            )

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_embedding_bag_device(self, device, dtypes):
        if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu":
            self.skipTest("bfloat16 not supported with Jetson cpu")
        with set_default_dtype(torch.double):
            self._test_EmbeddingBag(
                device,
                "sum",
                False,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
            )
            self._test_EmbeddingBag(
                device,
                "mean",
                False,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
            )
            self._test_EmbeddingBag(
                device,
                "max",
                False,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
            )

            test_backward = False
            if self.device_type == "cuda":
                # see 'todo' in test_embedding_bag.
                test_backward = dtypes[2] is not torch.float16
            elif self.device_type == "cpu":
                # TODO: figure out why precision on sparse embeddings isn't the
                # same as for dense.
                test_backward = (
                    dtypes[2] is not torch.float and dtypes[2] is not torch.float16
                )

            self._test_EmbeddingBag(
                device,
                "sum",
                True,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=test_backward,
            )
            self._test_EmbeddingBag(
                device,
                "mean",
                True,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=test_backward,
            )

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
        weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)

        weight_tensor_non_contig = weight_tensor[
            :, :3
        ]  # This is non-contiguous strided.
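        # Column-slicing keeps the parent's row stride, so the (3, 3) view has
        # stride (4, 1) and is not contiguous.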
        weight_tensor_contig = (
            weight_tensor_non_contig.clone().contiguous()
        )  # Contig-strided.

        index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 2], dtype=dtypes[1], device=device)
        for mode in ["sum", "mean", "max"]:
            output_non_contig = F.embedding_bag(
                input=index,
                weight=weight_tensor_non_contig,
                offsets=offsets,
                mode=mode,
            )
            output_contig = F.embedding_bag(
                input=index,
                weight=weight_tensor_contig,
                offsets=offsets,
                mode=mode,
            )
            # Compare inside the loop so that every mode is checked, not just
            # the last one.
            self.assertEqual(output_non_contig, output_contig)

    @onlyNativeDeviceTypes  # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_bfloat16(self, device, dtypes):
        with set_default_dtype(torch.double):
            self._test_EmbeddingBag(
                device,
                "sum",
                True,
                wdtype=torch.bfloat16,
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=True,
            )
            self._test_EmbeddingBag(
                device,
                "mean",
                True,
                wdtype=torch.bfloat16,
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=True,
            )

    @onlyNativeDeviceTypes  # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_half(self, device, dtypes):
        self._test_EmbeddingBag(
            device,
            "sum",
            True,
            wdtype=torch.float16,
            dtype=dtypes[0],
            odtype=dtypes[1],
            test_backward=True,
        )


instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals())
instantiate_parametrized_tests(TestEmbeddingNN)

if __name__ == "__main__":
    run_tests()