# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIf,
    skipMeta,
    TEST_WITH_ROCM,
)

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    _assertGradAndGradgradChecks,
    dtype2prec_DONTUSE,
    instantiate_parametrized_tests,
    IS_JETSON,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    skipIfTorchDynamo,
)


class TestEmbeddingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_embedding_max_norm_unsorted_repeating_indices(self):
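        # max_norm renormalizes looked-up rows in place, so the CUDA path must
        # handle unsorted, repeating indices and still match the CPU result.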
        def create_embedding(device):
            # Seed RNG so we get the same Embedding each time
            torch.manual_seed(0)
            return torch.nn.Embedding(
                num_embeddings=20, embedding_dim=64, max_norm=1.0
            ).to(device)

        ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000)
        out_cpu = create_embedding("cpu")(ix)

        ix = ix.to("cuda")
        out = create_embedding("cuda")(ix)
        self.assertEqual(out.cpu(), out_cpu)

    def test_embedding_sparse_basic(self):
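        # With sparse=True, backward produces a sparse gradient for the weight.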
        embedding = nn.Embedding(10, 20, sparse=True)
        input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
        embedding(input).sum().backward()
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

    def test_embedding_sparse_empty_tensor(self):
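        # Degenerate shapes (zero embeddings or zero embedding_dim) should
        # still yield a sparse gradient whose shape matches the weight.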
        embedding = nn.Embedding(0, 0, sparse=True)
        input = torch.tensor([], dtype=torch.int64)
        embedding(input).sum().backward()
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

        embedding = nn.Embedding(10, 0, sparse=True)
        input = torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]])
        embedding(input).sum().backward()
        self.assertTrue(embedding.weight.grad.is_sparse)
        self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

    def test_move_sparse_half_embedding(self):
        embedding = nn.Embedding(10, 3, sparse=True)
        self.assertEqual(embedding.weight.device.type, "cpu")
        self.assertEqual(embedding.weight.dtype, torch.get_default_dtype())
        embedding.to(torch.float16)
        self.assertEqual(embedding.weight.dtype, torch.float16)
        self.assertEqual(embedding.embedding_dim, 3)
        self.assertEqual(embedding.num_embeddings, 10)

        if torch.cuda.is_available():
            embedding.to("cuda")
            self.assertEqual(embedding.weight.device.type, "cuda")
            embedding.to("cpu")
            self.assertEqual(embedding.weight.device.type, "cpu")

    def test_embedding_max_norm(self):
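        # Looked-up rows are renormalized to max_norm, so the repeated index 8
        # must produce identical rows and every output row must satisfy the bound.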
        embedding = nn.Embedding(22, 5, max_norm=1.0)
        input = torch.tensor([2, 8, 8, 6], dtype=torch.long)
        output = embedding(input)
        self.assertEqual(output[1], output[2])
        self.assertTrue(output.data.norm(p=2, dim=1).le(1).all())

    @parametrize_test(
        "dtype",
        (
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float,
            torch.double,
        ),
    )
    def test_embedding_from_pretrained(self, dtype):
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
        embedding = nn.Embedding.from_pretrained(a)
        self.assertEqual(a, embedding.weight.data)

        input = torch.LongTensor([0, 1])
        output = embedding(input)
        self.assertEqual(a, output)

    def test_embedding_bag_from_pretrained(self):
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        embedding = nn.EmbeddingBag.from_pretrained(a)
        self.assertEqual(a, embedding.weight)

        input = torch.tensor([0, 1], dtype=torch.long)
        output = embedding(input, torch.arange(input.size(0)))
        self.assertEqual(a, output)

    def test_embedding_from_pretrained_padding_idx(self):
        padding_idx = 2
        padding_vec = torch.ones(3) * 7
        embeddings = torch.rand(4, 3, requires_grad=True)
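        # In-place edits to a leaf tensor that requires grad must run under no_grad.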
        with torch.no_grad():
            embeddings[padding_idx] = padding_vec
        embedding_nn = nn.Embedding.from_pretrained(embeddings, padding_idx=padding_idx)
        self.assertEqual(embedding_nn.weight[padding_idx], padding_vec)

    def test_embedding_bag_from_pretrained_padding_idx(self):
        padding_idx = 2
        embeddings = torch.rand(4, 3, requires_grad=True)
        embedding_nn = nn.EmbeddingBag.from_pretrained(
            embeddings, padding_idx=padding_idx
        )
        self.assertEqual(embedding_nn.weight, embeddings)

    def test_embedding_from_pretrained_options(self):
        with set_default_dtype(torch.double):
            a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            opts = {
                "max_norm": 2.0,
                "norm_type": 0.5,
                "scale_grad_by_freq": False,
                "sparse": True,
            }
            embedding = nn.Embedding.from_pretrained(a, **opts)
            input = torch.LongTensor([0, 1])
            output = embedding(input)
            # test output and that weight matrix was renormalized
            self.assertEqual(a, output)
            self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
            self.assertTrue(
                output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all()
            )

    def test_embedding_functional(self):
        a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        embeddings = torch.rand(4, 3, requires_grad=True)

        embed_old = torch.nn.Embedding(4, 3)
        embed_old.weight.data = embeddings.data
        # Trivial under eager, but useful when running under
        # PYTORCH_TEST_WITH_DYNAMO=1, since it ensures that setattr works correctly.
        self.assertEqual(embed_old.weight.data, embeddings.data)
        res_old = embed_old(a)

        res_F = F.embedding(a, embeddings)
        self.assertEqual(res_old, res_F)

        embed_old = torch.nn.Embedding(4, 3)
        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
        res_old = embed_old(a)
        res_F = F.embedding(a, embeddings, padding_idx=2)

        self.assertEqual(res_old, res_F)

    # https://github.com/pytorch/pytorch/issues/130806
    @largeTensorTest("40GB", device="cuda")
    def test_large_tensors(self):
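        # 131072 indices x 16384 features -> 2**31 output elements, just past
        # the signed 32-bit range.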
        input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
        w = torch.randn([16032, 16384], device="cuda")
        out = torch.nn.functional.embedding(input, w)
        self.assertEqual(out.dim(), 2)
        self.assertEqual(out.numel(), 2147483648)

    def test_embedding_bag_functional(self):
        a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        embeddings = torch.rand(4, 3, requires_grad=True)

        embed_old = torch.nn.EmbeddingBag(4, 3)
        embed_old.weight = torch.nn.Parameter(embeddings)
        res_old = embed_old(a)

        res_F = F.embedding_bag(a, embeddings)
        self.assertEqual(res_old, res_F)

        embed_old = torch.nn.EmbeddingBag(4, 3)
        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
        res_old = embed_old(a)
        res_F = F.embedding_bag(a, embeddings, padding_idx=2)

        self.assertEqual(res_old, res_F)

    # Make sure an error is raised if padding_idx is out of bounds
    def test_embedding_bag_padding_idx_error(self):
        a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long)
        num_embeddings = 4
        num_features = 3
        embeddings = torch.rand(num_embeddings, num_features, requires_grad=True)

        functional_err_msg = r"padding_idx must be within the number of embeddings"
        module_err_msg = r"padding_idx must be within num_embeddings"

        for padding_idx in range(-(num_embeddings + 2), (num_embeddings + 2)):
            if (padding_idx < -num_embeddings) or (padding_idx >= num_embeddings):
                with self.assertRaisesRegex(RuntimeError, functional_err_msg):
                    F.embedding_bag(a, embeddings, padding_idx=padding_idx)
                with self.assertRaisesRegex(AssertionError, module_err_msg):
                    torch.nn.EmbeddingBag(
                        num_embeddings, num_features, padding_idx=padding_idx
                    )
            else:
                F.embedding_bag(a, embeddings, padding_idx=padding_idx)
                torch.nn.EmbeddingBag(
                    num_embeddings, num_features, padding_idx=padding_idx
                )

    def test_embeddingbag_from_pretrained(self):
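        # The default mode is "mean", so one bag over rows 0 and 1 equals the
        # column-wise mean of the pretrained weights.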
        a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        embeddingbag = nn.EmbeddingBag.from_pretrained(a)
        self.assertEqual(a, embeddingbag.weight.data)

        input = torch.LongTensor([[0, 1]])
        output = embeddingbag(input)
        self.assertEqual(a.mean(0, keepdim=True), output)

    def test_embeddingbag_from_pretrained_options(self):
        with set_default_dtype(torch.double):
            a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            opts = {
                "max_norm": 2.0,
                "norm_type": 0.5,
                "scale_grad_by_freq": False,
                "mode": "max",
                "sparse": False,
            }
            embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts)

            input = torch.LongTensor([[0, 1]])
            output = embeddingbag(input)
            self.assertEqual(a.max(0, keepdim=True)[0], output)
            self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all())
            self.assertTrue(
                a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all()
            )

    def test_embeddingbag_include_last_offset(self):
        # Test case from https://github.com/pytorch/pytorch/issues/89677
        embeddingbag = nn.EmbeddingBag(100, 3, include_last_offset=True, padding_idx=61)
        input = torch.tensor([0, 1, 2, 3])
        out = embeddingbag(input, torch.tensor([0, 3, 3]))
        out2 = embeddingbag(input, torch.tensor([0, 3, 4]))

        weight = embeddingbag.weight
        row0 = weight[0:3].mean(0)
        row1 = weight[3]
        ref_out = torch.stack([row0, row1])

        self.assertEqual(ref_out, out)
        self.assertEqual(ref_out, out2)


class TestEmbeddingNNDeviceType(NNTestCase):
    def test_embedding_dense_grad(self, device):
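        # Run in double precision: gradcheck's numerical gradients are
        # unreliable in float32.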
        with set_default_dtype(torch.double):
            embd = nn.Embedding(20, 20).to(device)
            weight = embd.weight

            def fn_wrapper(device):
                def fn(weight):
                    inp = torch.tensor(
                        [[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long
                    ).to(device)
                    return torch.nn.functional.embedding(inp, weight)

                return fn

            fn = fn_wrapper(device)
            _assertGradAndGradgradChecks(self, fn, (weight,))

    def test_embedding_scalar_weight_error(self, device):
        indices = torch.rand(2, 2, device=device).long()
        weights = [
            torch.tensor(1.0, device=device),
            torch.tensor(1.0, device=device).reshape(1, 1, 1),
        ]

        for weight in weights:
            with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"):
                torch.nn.functional.embedding(indices, weight)

    @dtypesIfCUDA(torch.float16, torch.float64)
    @dtypes(torch.float64)
    def test_embedding_backward(self, device, dtype):
        embedding = nn.Embedding(10, 3, sparse=True)
        tensor = torch.tensor([[7, 1, 3]])
        ones = torch.tensor(1.0, dtype=dtype).expand(3, 3)
        tensorTwice = tensor.repeat(1, 2)
        onesTwice = torch.cat((ones, ones))

        embedding = embedding.to(dtype=dtype).to(device)
        tensor = tensor.to(device)
        ones = ones.to(device)
        tensorTwice = tensorTwice.to(device)
        onesTwice = onesTwice.to(device)

        embedding.zero_grad()
        embedding(tensor[0]).sum().backward()
        self.assertEqual(embedding.weight.grad._indices(), tensor)
        self.assertEqual(embedding.weight.grad._values(), ones)

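        # The sparse gradient accumulates uncoalesced: a second backward
        # appends indices and values rather than summing duplicates.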
        embedding.zero_grad()
        embedding(tensor[0]).sum().backward()
        embedding(tensor[0]).sum().backward()
        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
        self.assertEqual(embedding.weight.grad._values(), onesTwice)

        embedding.zero_grad()
        embedding(tensor[0]).sum().backward()
        tensor[0, 0] = 8
        embedding(tensor[0]).sum().backward()
        tensorTwice[0, 3] = 8
        self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
        self.assertEqual(embedding.weight.grad._values(), onesTwice)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_max_norm_backward(self, device, dtype):
        # Can't use gradcheck: the in-place renorm makes the analytical
        # gradients differ from the computed ones.
        weight = torch.randn((4, 4), device=device, dtype=dtype) * 2
        weight.requires_grad_()
        inp_list = [0, 1, 2, 2]
        inp = torch.tensor(inp_list, device=device)
        out = nn.functional.embedding(inp, weight, max_norm=1.0).sum()
        out.backward()

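        # Each weight row's gradient counts how many times its index appears
        # in inp_list: index 0 -> 1, index 1 -> 1, index 2 -> 2, index 3 -> 0.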
        expected_grad = (
            torch.tensor([[1.0, 1.0, 2.0, 0.0]], device=device, dtype=dtype)
            .transpose(0, 1)
            .expand(4, 4)
        )
        self.assertEqual(weight.grad, expected_grad)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_max_norm_fwd_AD(self, device, dtype):
        if torch.device(device).type == "xla":
            self.skipTest("forward AD doesn't work on xla")

        # Can't use gradcheck: the in-place renorm makes the analytical
        # gradients differ from the computed ones.
        weight = torch.randn((4, 4), device=device, dtype=dtype) * 2
        tangent = torch.ones((4, 4), device=device, dtype=dtype)
        inp = torch.tensor([[0, 1], [2, 2]], device=device)
        with torch.autograd.forward_ad.dual_level():
            dual_weight = torch.autograd.forward_ad.make_dual(weight, tangent)
            out = nn.functional.embedding(inp, dual_weight, max_norm=1.0)
            jvp = torch.autograd.forward_ad.unpack_dual(out).tangent

        expected_grad = torch.ones((2, 2, 4), device=device, dtype=dtype)
        self.assertEqual(jvp, expected_grad)

    @dtypesIfCUDA(
        *(
            (torch.float, torch.double, torch.bfloat16, torch.half)
            if TEST_WITH_ROCM
            else (torch.float, torch.double, torch.half)
        )
    )
    @dtypes(torch.float32)
    def test_embedding_padding_idx(self, device, dtype):
        embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype)
        input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][0].sum(), 0)
        self.assertEqual(output[1][2].sum(), 0)

        embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True).to(device, dtype)
        input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][0].sum(), 0)
        self.assertEqual(output[1][2].sum(), 0)

        # negative indexing check for padding_idx
        # padding_idx=-2, num_embeddings=10 ==> index 8 padded
        embedding = nn.Embedding(10, 20, padding_idx=-2).to(device, dtype)
        input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][2].sum(), 0)
        self.assertEqual(output[1][1].sum(), 0)

        embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True).to(device, dtype)
        input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[0][2].sum(), 0)
        self.assertEqual(output[1][1].sum(), 0)

        # change padding vector
        padding_vector = torch.ones(20, dtype=dtype, device=device)
        embedding = nn.Embedding(10, 20, padding_idx=2, sparse=True).to(device, dtype)
        with torch.no_grad():
            embedding.weight[2] = padding_vector
        input = torch.tensor([0, 2], dtype=torch.long).to(device)
        output = embedding(input)
        self.assertEqual(output[1], padding_vector)

        # out of bounds check for padding_idx
        self.assertRaises(
            AssertionError,
            nn.Embedding,
            num_embeddings=10,
            embedding_dim=20,
            padding_idx=25,
        )
        self.assertRaises(
            AssertionError,
            nn.Embedding,
            num_embeddings=10,
            embedding_dim=20,
            padding_idx=-25,
        )

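        # The padding row's gradient must stay zero no matter how often
        # padding_idx appears in the input.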
        padding_idx = 0
        embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to(device, dtype)
        for n in (
            1,
            2,
            1000,
        ):  # Need large N to trigger all the methods we have implemented
            for other_indices in ([], [1, 3], [2]):
                indices = torch.tensor(
                    other_indices + [padding_idx] * n, dtype=torch.long
                ).to(device)
                pre = embedding.weight[padding_idx].clone()
                embedding(indices).sum().backward()
                after = (embedding.weight + embedding.weight.grad)[padding_idx]
                embedding.zero_grad()
                self.assertEqual(after, pre)

                # test double backward
                emb_sum = embedding(indices).sum()
                emb_grad = torch.autograd.grad(
                    outputs=emb_sum,
                    inputs=list(embedding.parameters()),
                    retain_graph=True,
                )
                scalar = emb_grad[0].sum() + emb_sum
                scalar.backward()
                after = (embedding.weight + embedding.weight.grad)[padding_idx]
                embedding.zero_grad()
                self.assertEqual(after, pre)

    # Check correctness of torch.nn.functional.embedding_bag forward and
    # backward functions with padding_idx, given a 1D input separated into bags
    # with an offset array. Compare against an equivalent 2D input that uses
    # padding indices to fill in the gaps indicated by the offset array.

    @skipIfTorchDynamo("see https://github.com/pytorch/pytorch/pull/95621")
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.half, torch.bfloat16)
    def test_embedding_bag_1D_padding_idx(self, device, dtype):
        num_features = 3
        max_indices_per_bag = 10
        num_bags = 10
        num_words = 100

        def gen_1D_indices_offsets(include_last_offset, allpad):
            indices = []
            offsets = []
            cur_offset = 0

            # Make one bag full and one bag empty, for extra coverage
            empty_bag = random.randint(0, num_bags - 1)
            full_bag = empty_bag
            while full_bag == empty_bag:
                full_bag = random.randint(0, num_bags - 1)

            for bag in range(num_bags):
                offsets.append(cur_offset)
                if bag == full_bag:
                    bag_size = max_indices_per_bag
                elif bag == empty_bag:
                    bag_size = 0
                else:
                    bag_size = random.randint(1, max_indices_per_bag - 1)
                indices += [
                    1 if allpad else random.randint(0, num_words - 1)
                    for _ in range(bag_size)
                ]
                cur_offset += bag_size

            # embedding_bag requires the first entry of offsets to be 0
            assert offsets[0] == 0

            indices = torch.tensor(indices, device=device)

            if include_last_offset:
                offsets.append(indices.size(0))

            offsets = torch.tensor(offsets, device=device)

            return indices, offsets

        # Convert a 1-D indices-offsets representation into 2-D. Fill any empty
        # indices with padding_idx
        def gen_2D_indices_from_1D(
            indices_1D, offsets, include_last_offset, padding_idx
        ):
            assert offsets[0] == 0
            if include_last_offset:
                offsets = offsets[:-1]
            indices_2D = torch.empty(
                num_bags, max_indices_per_bag, device=device, dtype=torch.long
            )
            for bag in range(num_bags):
                # Determine the start and end position of the bag within indices_1D
                start = offsets[bag]
                end = len(indices_1D) if bag + 1 == num_bags else offsets[bag + 1]
                end = min(len(indices_1D), end)

                # Pull out the bag's indices from indices_1D, and fill any
                # remaining space with padding indices
                indices_in_bag = []
                for item_pos in range(0, max_indices_per_bag):
                    if (start + item_pos) < end:
                        indices_in_bag.append(indices_1D[start + item_pos])
                    else:
                        indices_in_bag.append(padding_idx)
                indices_2D[bag] = torch.tensor(indices_in_bag, device=device)

            return indices_2D

        test_cases = product(
            ["max", "mean", "sum"], [False, True], [False, True], [False, True]
        )

        for mode, sparse, include_last_offset, allpad in test_cases:
            # mode="max" does not support sparse gradients or bfloat16
            if mode == "max":
                if sparse or (dtype == torch.bfloat16):
                    continue
            indices_1D, offsets = gen_1D_indices_offsets(include_last_offset, allpad)
            for padding_idx_1D in list(set(indices_1D.tolist())) + [None]:
                msg = (
                    f"mode: '{mode}', sparse: {sparse}, include_last_offset: {include_last_offset}, "
                    f"padding_idx_1D: {padding_idx_1D}"
                )

                # If the 1D input does not use a padding index, the 2D input still
                # needs one, so add a dummy word to the weights to act as the pad.
                padding_idx_2D = (
                    padding_idx_1D if padding_idx_1D is not None else num_words
                )
                num_words_with_padding = (
                    num_words if padding_idx_1D is not None else num_words + 1
                )

                indices_2D = gen_2D_indices_from_1D(
                    indices_1D, offsets, include_last_offset, padding_idx_2D
                )

                weights = torch.randn(
                    num_words_with_padding,
                    num_features,
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                )
                weights_check = weights.clone().detach().requires_grad_(True)

                bag = torch.nn.functional.embedding_bag(
                    indices_1D,
                    weights,
                    offsets,
                    padding_idx=padding_idx_1D,
                    mode=mode,
                    sparse=sparse,
                    include_last_offset=include_last_offset,
                )

                bag_check = torch.nn.functional.embedding_bag(
                    indices_2D,
                    weights_check,
                    padding_idx=padding_idx_2D,
                    mode=mode,
                    sparse=sparse,
                )
                self.assertEqual(bag, bag_check, msg=msg)

                bag.sum().backward()
                bag_check.sum().backward()

                # Half-precision gradients can mismatch by larger margins than
                # other dtypes, so loosen the tolerances for them
                if dtype in [torch.half, torch.bfloat16]:
                    atol = 0.01
                    rtol = 0.01
                else:
                    atol = None
                    rtol = None
                self.assertEqual(
                    weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol
                )

    # Check correctness of torch.nn.functional.embedding_bag forward and
    # backward functions with padding_idx, given a 2D indices input. Compare
    # against torch.nn.functional.embedding followed by a reduction.
638*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
639*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
640*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.bfloat16)
641*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_2D_padding_idx(self, device, dtype):
642*da0073e9SAndroid Build Coastguard Worker        # Use a Python implementation of embedding_bag with padding_idx support
643*da0073e9SAndroid Build Coastguard Worker        # to check torch.nn.functional.embedding_bag correctness
644*da0073e9SAndroid Build Coastguard Worker        def embedding_bag_check(indices, weights, mode, sparse, padding_idx):
645*da0073e9SAndroid Build Coastguard Worker            assert padding_idx is not None
646*da0073e9SAndroid Build Coastguard Worker            embedding = torch.nn.functional.embedding(
647*da0073e9SAndroid Build Coastguard Worker                indices, weights, padding_idx=padding_idx, sparse=sparse
648*da0073e9SAndroid Build Coastguard Worker            )
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker            reduction_dim = indices.dim() - 1
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker            if mode == "sum" or mode == "mean":
653*da0073e9SAndroid Build Coastguard Worker                # We must avoid including elements at padding_idx in the
654*da0073e9SAndroid Build Coastguard Worker                # sum/mean, so multiply those elements by 0, and multiply
655*da0073e9SAndroid Build Coastguard Worker                # all other elements by 1
656*da0073e9SAndroid Build Coastguard Worker                per_sample_weights = indices.ne(padding_idx).to(dtype).unsqueeze(-1)
657*da0073e9SAndroid Build Coastguard Worker                res = embedding.mul(per_sample_weights).sum(dim=reduction_dim)
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker                if mode == "mean":
660*da0073e9SAndroid Build Coastguard Worker                    weights_sum = per_sample_weights.sum(dim=reduction_dim)
661*da0073e9SAndroid Build Coastguard Worker                    res = res.div(weights_sum)
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker            elif mode == "max":
664*da0073e9SAndroid Build Coastguard Worker                # We must avoid allowing elements at padding_idx to be chosen
665*da0073e9SAndroid Build Coastguard Worker                # as the max, so set those elements to negative infinity
666*da0073e9SAndroid Build Coastguard Worker                res = embedding.masked_fill(
667*da0073e9SAndroid Build Coastguard Worker                    indices.unsqueeze(-1) == padding_idx, -float("inf")
668*da0073e9SAndroid Build Coastguard Worker                ).amax(dim=reduction_dim)
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker            else:
671*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(f"mode '{mode}' is not available")
672*da0073e9SAndroid Build Coastguard Worker
673*da0073e9SAndroid Build Coastguard Worker            # If a row is all padding, set its corresponding result row to 0.
674*da0073e9SAndroid Build Coastguard Worker            # This is needed because the above mean and max mode
675*da0073e9SAndroid Build Coastguard Worker            # implementations set these elements to nan and -inf, respectively
676*da0073e9SAndroid Build Coastguard Worker            if mode in ["mean", "max"]:
677*da0073e9SAndroid Build Coastguard Worker                res = res.masked_fill(
678*da0073e9SAndroid Build Coastguard Worker                    indices.eq(padding_idx).all(dim=-1).unsqueeze(-1), 0
679*da0073e9SAndroid Build Coastguard Worker                )
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker            return res
682*da0073e9SAndroid Build Coastguard Worker
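        # A tiny illustration of the check above (hypothetical values, not part
        # of the test run): for mode="sum", zeroing the padded positions before
        # reducing matches embedding_bag, e.g.
        #
        #   idx = torch.tensor([[0, 1, 1]])
        #   w = torch.randn(3, 2)
        #   mask = idx.ne(1).to(w.dtype).unsqueeze(-1)
        #   ref = (F.embedding(idx, w, padding_idx=1) * mask).sum(dim=1)
        #   bag = F.embedding_bag(idx, w, padding_idx=1, mode="sum")
        #   # torch.allclose(ref, bag) should hold
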
683*da0073e9SAndroid Build Coastguard Worker        num_features = 3
684*da0073e9SAndroid Build Coastguard Worker        num_words = 10
685*da0073e9SAndroid Build Coastguard Worker        indices_dim1 = 10
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker        for mode, sparse, allpad, indices_dim0 in product(
688*da0073e9SAndroid Build Coastguard Worker            ["max", "mean", "sum"], [False, True], [False, True], [1, 10]
689*da0073e9SAndroid Build Coastguard Worker        ):
690*da0073e9SAndroid Build Coastguard Worker            # mode="max" does not support sparse gradients or bfloat16
691*da0073e9SAndroid Build Coastguard Worker            if mode == "max":
692*da0073e9SAndroid Build Coastguard Worker                if sparse or (dtype == torch.bfloat16):
693*da0073e9SAndroid Build Coastguard Worker                    continue
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker            if allpad:
696*da0073e9SAndroid Build Coastguard Worker                indices = torch.empty(
697*da0073e9SAndroid Build Coastguard Worker                    indices_dim0, indices_dim1, dtype=torch.long, device=device
698*da0073e9SAndroid Build Coastguard Worker                ).fill_(1)
699*da0073e9SAndroid Build Coastguard Worker            else:
700*da0073e9SAndroid Build Coastguard Worker                indices = torch.randint(
701*da0073e9SAndroid Build Coastguard Worker                    0, num_words, (indices_dim0, indices_dim1), device=device
702*da0073e9SAndroid Build Coastguard Worker                )
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker                if indices_dim0 > 1:
705*da0073e9SAndroid Build Coastguard Worker                    # Fill one row with a single repeated index so that, when
706*da0073e9SAndroid Build Coastguard Worker                    # padding_idx equals that index, the row is fully padded
707*da0073e9SAndroid Build Coastguard Worker                    duplicate_row = random.randint(0, indices_dim0 - 1)
708*da0073e9SAndroid Build Coastguard Worker                    indices[duplicate_row] = indices[duplicate_row][0]
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker            for padding_idx in list(set(indices.flatten(0, -1).tolist())):
711*da0073e9SAndroid Build Coastguard Worker                weights = torch.randn(
712*da0073e9SAndroid Build Coastguard Worker                    num_words,
713*da0073e9SAndroid Build Coastguard Worker                    num_features,
714*da0073e9SAndroid Build Coastguard Worker                    dtype=dtype,
715*da0073e9SAndroid Build Coastguard Worker                    device=device,
716*da0073e9SAndroid Build Coastguard Worker                    requires_grad=True,
717*da0073e9SAndroid Build Coastguard Worker                )
718*da0073e9SAndroid Build Coastguard Worker                weights_check = weights.clone().detach().requires_grad_(True)
719*da0073e9SAndroid Build Coastguard Worker
720*da0073e9SAndroid Build Coastguard Worker                msg = (
721*da0073e9SAndroid Build Coastguard Worker                    f"mode: '{mode}', sparse: {sparse}, padding_idx: {padding_idx}, "
722*da0073e9SAndroid Build Coastguard Worker                    f"allpad: {allpad}, indices.size(): {indices.size()}"
723*da0073e9SAndroid Build Coastguard Worker                )
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker                # Check forward with a Python implementation of padding_idx embedding_bag
726*da0073e9SAndroid Build Coastguard Worker                bag_check = embedding_bag_check(
727*da0073e9SAndroid Build Coastguard Worker                    indices, weights_check, mode, sparse, padding_idx
728*da0073e9SAndroid Build Coastguard Worker                )
729*da0073e9SAndroid Build Coastguard Worker                bag = torch.nn.functional.embedding_bag(
730*da0073e9SAndroid Build Coastguard Worker                    indices, weights, padding_idx=padding_idx, mode=mode, sparse=sparse
731*da0073e9SAndroid Build Coastguard Worker                )
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(bag, bag_check, msg=msg)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker                bag_check.sum().backward()
736*da0073e9SAndroid Build Coastguard Worker                grad_check = weights_check.grad
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker                bag.sum().backward()
739*da0073e9SAndroid Build Coastguard Worker                grad = weights.grad
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker                # Sometimes, half and bfloat16 dtype gradients mismatch by a
742*da0073e9SAndroid Build Coastguard Worker                # greater amount than other dtypes
743*da0073e9SAndroid Build Coastguard Worker                if dtype in [torch.half, torch.bfloat16]:
744*da0073e9SAndroid Build Coastguard Worker                    atol = 0.01
745*da0073e9SAndroid Build Coastguard Worker                    rtol = 0.01
746*da0073e9SAndroid Build Coastguard Worker                else:
747*da0073e9SAndroid Build Coastguard Worker                    atol = None
748*da0073e9SAndroid Build Coastguard Worker                    rtol = None
749*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)
750*da0073e9SAndroid Build Coastguard Worker
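    # A small sketch (illustrative only) of a related property exercised above:
    # positions equal to padding_idx contribute no gradient to the weight, e.g.
    #
    #   w = torch.randn(4, 2, requires_grad=True)
    #   out = F.embedding_bag(torch.tensor([[0, 1]]), w, padding_idx=1, mode="sum")
    #   out.sum().backward()
    #   # w.grad[1] is all zeros, w.grad[0] is all ones
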
751*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
752*da0073e9SAndroid Build Coastguard Worker    @dtypes(
753*da0073e9SAndroid Build Coastguard Worker        *(
754*da0073e9SAndroid Build Coastguard Worker            (torch.float, torch.double, torch.bfloat16, torch.half)
755*da0073e9SAndroid Build Coastguard Worker            if TEST_WITH_ROCM
756*da0073e9SAndroid Build Coastguard Worker            else (torch.float, torch.double, torch.half)
757*da0073e9SAndroid Build Coastguard Worker        )
758*da0073e9SAndroid Build Coastguard Worker    )
759*da0073e9SAndroid Build Coastguard Worker    def test_embedding_max_norm_device(self, device, dtype):
760*da0073e9SAndroid Build Coastguard Worker        embedding = nn.Embedding(22, 5, max_norm=1.0).to(device, dtype=dtype)
761*da0073e9SAndroid Build Coastguard Worker        # nn.Embedding only takes LongTensor as input
762*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([2, 8, 8, 6], device=device, dtype=torch.long)
763*da0073e9SAndroid Build Coastguard Worker        output = embedding(input)
764*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output[1], output[2])
765*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(output.data.norm(p=2, dim=1).le(1).all())
766*da0073e9SAndroid Build Coastguard Worker
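    # Sketch of the max_norm contract checked above (illustrative): the forward
    # pass renormalizes the looked-up rows of .weight in place, leaving other
    # rows untouched:
    #
    #   emb = nn.Embedding(10, 4, max_norm=1.0)
    #   with torch.no_grad():
    #       emb.weight.mul_(10)  # push every row's norm above 1
    #   emb(torch.tensor([3]))
    #   # emb.weight[3].norm() <= 1.0 now; rows other than 3 keep large norms
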
767*da0073e9SAndroid Build Coastguard Worker    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
768*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_empty_input(self, device, dtypes):
769*da0073e9SAndroid Build Coastguard Worker        m = 4
770*da0073e9SAndroid Build Coastguard Worker        n = 3
771*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([], device=device, dtype=dtypes[0])
772*da0073e9SAndroid Build Coastguard Worker        for sparse in [True, False]:
773*da0073e9SAndroid Build Coastguard Worker            Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse)
774*da0073e9SAndroid Build Coastguard Worker            Embed.to(device)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker            output = Embed(
777*da0073e9SAndroid Build Coastguard Worker                input=x, offsets=torch.tensor([0], device=device, dtype=dtypes[1])
778*da0073e9SAndroid Build Coastguard Worker            )
779*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, torch.zeros_like(output))
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker            output = Embed(
782*da0073e9SAndroid Build Coastguard Worker                input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtypes[1])
783*da0073e9SAndroid Build Coastguard Worker            )
784*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, torch.zeros_like(output))
785*da0073e9SAndroid Build Coastguard Worker
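    # Sketch of the offsets semantics relied on above (illustrative): each
    # consecutive pair of offsets delimits one bag, so an empty input with
    # offsets [0, 0] yields two empty bags, each reduced to a zero row:
    #
    #   eb = nn.EmbeddingBag(4, 3)
    #   out = eb(torch.tensor([], dtype=torch.long), torch.tensor([0, 0]))
    #   # out.shape == (2, 3) and out is all zeros
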
786*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.")
787*da0073e9SAndroid Build Coastguard Worker    @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long)))
788*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("padding_idx", [None, 0])
789*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("mode", ["sum", "mean", "max"])
790*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_out_of_bounds_idx(self, device, dtypes, padding_idx, mode):
792*da0073e9SAndroid Build Coastguard Worker        w_dtype, idx_dtype = dtypes
793*da0073e9SAndroid Build Coastguard Worker        # negative out-of-bound
794*da0073e9SAndroid Build Coastguard Worker        idx1 = torch.tensor([[-1, 1]], device=device, dtype=idx_dtype)
795*da0073e9SAndroid Build Coastguard Worker        # positive out-of-bound
796*da0073e9SAndroid Build Coastguard Worker        idx2 = torch.tensor([[11, 8]], device=device, dtype=idx_dtype)
797*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(10, 2, device=device, dtype=w_dtype)
798*da0073e9SAndroid Build Coastguard Worker        if mode == "sum":
799*da0073e9SAndroid Build Coastguard Worker            # Only mode="sum" supports per_sample_weights
800*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = (
801*da0073e9SAndroid Build Coastguard Worker                None,
802*da0073e9SAndroid Build Coastguard Worker                torch.randn_like(idx1, device=device, dtype=w_dtype),
803*da0073e9SAndroid Build Coastguard Worker            )
804*da0073e9SAndroid Build Coastguard Worker        else:
805*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = (None,)
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker        for p_s_weights, idx in itertools.product(per_sample_weights, (idx1, idx2)):
808*da0073e9SAndroid Build Coastguard Worker            msg = "Expected idx >= 0 && idx < num_embeddings"
809*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
810*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.embedding_bag(
811*da0073e9SAndroid Build Coastguard Worker                    idx,
812*da0073e9SAndroid Build Coastguard Worker                    weight,
813*da0073e9SAndroid Build Coastguard Worker                    per_sample_weights=p_s_weights,
814*da0073e9SAndroid Build Coastguard Worker                    padding_idx=padding_idx,
815*da0073e9SAndroid Build Coastguard Worker                    mode=mode,
816*da0073e9SAndroid Build Coastguard Worker                )
817*da0073e9SAndroid Build Coastguard Worker
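    # Illustrative sketch of the failure mode checked above: on CPU, any index
    # outside [0, num_embeddings) fails fast, e.g.
    #
    #   w = torch.randn(10, 2)
    #   F.embedding_bag(torch.tensor([[10, 0]]), w, mode="sum")  # RuntimeError
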
818*da0073e9SAndroid Build Coastguard Worker    def test_embedding_bag_dimension_errors(self, device):
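        # Note on argument order: F.embedding_bag takes (input, weight, offsets),
        # while the remaining entries take (weight, indices, offsets); the lambda
        # swaps its first two arguments so all four can be called uniformly below.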
819*da0073e9SAndroid Build Coastguard Worker        funcs = (
820*da0073e9SAndroid Build Coastguard Worker            lambda x, y, z: torch.nn.functional.embedding_bag(y, x, z),
821*da0073e9SAndroid Build Coastguard Worker            torch.embedding_bag,
822*da0073e9SAndroid Build Coastguard Worker            torch._embedding_bag,
823*da0073e9SAndroid Build Coastguard Worker            torch._embedding_bag_forward_only,
824*da0073e9SAndroid Build Coastguard Worker        )
825*da0073e9SAndroid Build Coastguard Worker        for i, f in enumerate(funcs):
826*da0073e9SAndroid Build Coastguard Worker            err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError
827*da0073e9SAndroid Build Coastguard Worker
828*da0073e9SAndroid Build Coastguard Worker            weight = torch.full((2, 6), 0, dtype=torch.float64, device=device)
837*da0073e9SAndroid Build Coastguard Worker            indices = torch.full((2, 0, 0, 6, 6), 2, dtype=torch.int64, device=device)
849*da0073e9SAndroid Build Coastguard Worker            offsets = torch.full((2, 0, 0, 6, 6), 0, dtype=torch.int64, device=device)
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker            if i == 0:
852*da0073e9SAndroid Build Coastguard Worker                error_msg = "input has to be 1D or 2D Tensor"
853*da0073e9SAndroid Build Coastguard Worker            else:
854*da0073e9SAndroid Build Coastguard Worker                error_msg = "input has to be a 1D or 2D Tensor"
855*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.disable(self.assertRaisesRegex)(
856*da0073e9SAndroid Build Coastguard Worker                err_type, error_msg, lambda: f(weight, indices, offsets)
857*da0073e9SAndroid Build Coastguard Worker            )
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker            weight = torch.full((2, 2), 0, dtype=torch.float64, device=device)
860*da0073e9SAndroid Build Coastguard Worker            indices = torch.full((2,), 1, dtype=torch.int64, device=device)
861*da0073e9SAndroid Build Coastguard Worker
862*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.disable(self.assertRaisesRegex)(
863*da0073e9SAndroid Build Coastguard Worker                err_type,
864*da0073e9SAndroid Build Coastguard Worker                "offsets has to be a 1D Tensor",
865*da0073e9SAndroid Build Coastguard Worker                lambda: f(weight, indices, offsets),
866*da0073e9SAndroid Build Coastguard Worker            )
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker            weight = torch.full((2, 2, 2), 0, dtype=torch.float64, device=device)
869*da0073e9SAndroid Build Coastguard Worker            indices = torch.full((2,), 2, dtype=torch.int64, device=device)
870*da0073e9SAndroid Build Coastguard Worker            offsets = torch.full((2,), 0, dtype=torch.int64, device=device)
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.disable(self.assertRaisesRegex)(
873*da0073e9SAndroid Build Coastguard Worker                err_type,
874*da0073e9SAndroid Build Coastguard Worker                "weight has to be a 2D Tensor",
875*da0073e9SAndroid Build Coastguard Worker                lambda: f(weight, indices, offsets),
876*da0073e9SAndroid Build Coastguard Worker            )
877*da0073e9SAndroid Build Coastguard Worker
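    # For contrast with the error cases above, a sketch of the supported shapes
    # (illustrative values): a 1D input with 1D offsets, or a 2D input with
    # offsets omitted (each row is one bag):
    #
    #   w = torch.randn(5, 3)
    #   F.embedding_bag(torch.tensor([0, 1, 2, 3]), w, torch.tensor([0, 2]))  # (2, 3)
    #   F.embedding_bag(torch.tensor([[0, 1], [2, 3]]), w)                    # (2, 3)
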
878*da0073e9SAndroid Build Coastguard Worker    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
879*da0073e9SAndroid Build Coastguard Worker    def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes):
880*da0073e9SAndroid Build Coastguard Worker        # Failure 1: mismatched embeddings / per_sample_weights dtype
881*da0073e9SAndroid Build Coastguard Worker        es = nn.EmbeddingBag(5, 2, mode="sum").to(dtype=torch.float, device=device)
882*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
883*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
884*da0073e9SAndroid Build Coastguard Worker        per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
885*da0073e9SAndroid Build Coastguard Worker        if device == "cpu":
886*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "have the same type as"):
887*da0073e9SAndroid Build Coastguard Worker                es(input, offsets, per_sample_weights)
888*da0073e9SAndroid Build Coastguard Worker        else:
889*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "expected scalar type"):
890*da0073e9SAndroid Build Coastguard Worker                es(input, offsets, per_sample_weights)
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        # Failure 2.1: input/per_sample_weights have different sizes (1d input)
893*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
894*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
895*da0073e9SAndroid Build Coastguard Worker        per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
896*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "same shape as the input"):
897*da0073e9SAndroid Build Coastguard Worker            es(input, offsets, per_sample_weights)
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker        # Failure 2.2: input/per_sample_weights have different sizes (2d input)
900*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
901*da0073e9SAndroid Build Coastguard Worker        offsets = None
902*da0073e9SAndroid Build Coastguard Worker        per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
903*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "same shape as the input"):
904*da0073e9SAndroid Build Coastguard Worker            es(input, offsets, per_sample_weights)
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Worker        # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean')
907*da0073e9SAndroid Build Coastguard Worker        for unsupported_mode in ("max", "mean"):
908*da0073e9SAndroid Build Coastguard Worker            es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
909*da0073e9SAndroid Build Coastguard Worker                dtype=torch.float, device=device
910*da0073e9SAndroid Build Coastguard Worker            )
911*da0073e9SAndroid Build Coastguard Worker            input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
912*da0073e9SAndroid Build Coastguard Worker            offsets = None
913*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
914*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
915*da0073e9SAndroid Build Coastguard Worker                NotImplementedError, "only supported for mode='sum'"
916*da0073e9SAndroid Build Coastguard Worker            ):
917*da0073e9SAndroid Build Coastguard Worker                es(input, offsets, per_sample_weights)
918*da0073e9SAndroid Build Coastguard Worker
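    # Sketch of the combination that does work (illustrative): mode="sum" with
    # per_sample_weights shaped exactly like the input:
    #
    #   es = nn.EmbeddingBag(5, 2, mode="sum")
    #   inp = torch.tensor([3, 1, 4])
    #   off = torch.tensor([0, 1])
    #   es(inp, off, per_sample_weights=torch.tensor([0.5, 1.0, 2.0]))
    #   # bag 0 = 0.5 * weight[3]; bag 1 = 1.0 * weight[1] + 2.0 * weight[4]
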
919*da0073e9SAndroid Build Coastguard Worker    def _embedding_bag_reference_impl(
920*da0073e9SAndroid Build Coastguard Worker        self,
921*da0073e9SAndroid Build Coastguard Worker        input,
922*da0073e9SAndroid Build Coastguard Worker        weight,
923*da0073e9SAndroid Build Coastguard Worker        offsets=None,
924*da0073e9SAndroid Build Coastguard Worker        mode="sum",
925*da0073e9SAndroid Build Coastguard Worker        per_sample_weights=None,
926*da0073e9SAndroid Build Coastguard Worker        include_last_offset=False,
927*da0073e9SAndroid Build Coastguard Worker    ):
928*da0073e9SAndroid Build Coastguard Worker        assert mode == "sum" or per_sample_weights is None
929*da0073e9SAndroid Build Coastguard Worker        assert offsets is not None
930*da0073e9SAndroid Build Coastguard Worker        if per_sample_weights is None:
931*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = torch.ones(input.size()).to(
932*da0073e9SAndroid Build Coastguard Worker                dtype=weight.dtype, device=weight.device
933*da0073e9SAndroid Build Coastguard Worker            )
934*da0073e9SAndroid Build Coastguard Worker        assert input.numel() == per_sample_weights.numel()
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        bags = []
937*da0073e9SAndroid Build Coastguard Worker        long_input = input.to(torch.long)
938*da0073e9SAndroid Build Coastguard Worker        embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(
939*da0073e9SAndroid Build Coastguard Worker            1
940*da0073e9SAndroid Build Coastguard Worker        )
941*da0073e9SAndroid Build Coastguard Worker        if include_last_offset:
942*da0073e9SAndroid Build Coastguard Worker            for index in range(len(offsets) - 1):
943*da0073e9SAndroid Build Coastguard Worker                offset = offsets[index]
944*da0073e9SAndroid Build Coastguard Worker                next_offset = offsets[index + 1]
945*da0073e9SAndroid Build Coastguard Worker                length = next_offset - offset
946*da0073e9SAndroid Build Coastguard Worker                if length == 0:
947*da0073e9SAndroid Build Coastguard Worker                    bags.append(
948*da0073e9SAndroid Build Coastguard Worker                        torch.tensor([0] * weight.size(1)).to(
949*da0073e9SAndroid Build Coastguard Worker                            dtype=embeddings.dtype, device=embeddings.device
950*da0073e9SAndroid Build Coastguard Worker                        )
951*da0073e9SAndroid Build Coastguard Worker                    )
952*da0073e9SAndroid Build Coastguard Worker                else:
953*da0073e9SAndroid Build Coastguard Worker                    if mode == "sum":
954*da0073e9SAndroid Build Coastguard Worker                        bags.append(embeddings.narrow(0, offset, length).sum(0))
955*da0073e9SAndroid Build Coastguard Worker                    elif mode == "mean":
956*da0073e9SAndroid Build Coastguard Worker                        bags.append(
957*da0073e9SAndroid Build Coastguard Worker                            embeddings.narrow(0, offset, length).sum(0).div(length)
958*da0073e9SAndroid Build Coastguard Worker                        )
959*da0073e9SAndroid Build Coastguard Worker                    else:
960*da0073e9SAndroid Build Coastguard Worker                        assert mode == "max"
961*da0073e9SAndroid Build Coastguard Worker                        bags.append(embeddings.narrow(0, offset, length).max(0)[0])
962*da0073e9SAndroid Build Coastguard Worker        else:
963*da0073e9SAndroid Build Coastguard Worker            for index, offset in enumerate(offsets):
964*da0073e9SAndroid Build Coastguard Worker                if index + 1 < len(offsets):
965*da0073e9SAndroid Build Coastguard Worker                    next_offset = offsets[index + 1]
966*da0073e9SAndroid Build Coastguard Worker                else:
967*da0073e9SAndroid Build Coastguard Worker                    next_offset = len(long_input)
968*da0073e9SAndroid Build Coastguard Worker                length = next_offset - offset
969*da0073e9SAndroid Build Coastguard Worker                if length == 0:
970*da0073e9SAndroid Build Coastguard Worker                    bags.append(
971*da0073e9SAndroid Build Coastguard Worker                        torch.tensor([0] * weight.size(1)).to(
972*da0073e9SAndroid Build Coastguard Worker                            dtype=embeddings.dtype, device=embeddings.device
973*da0073e9SAndroid Build Coastguard Worker                        )
974*da0073e9SAndroid Build Coastguard Worker                    )
975*da0073e9SAndroid Build Coastguard Worker                else:
976*da0073e9SAndroid Build Coastguard Worker                    if mode == "sum":
977*da0073e9SAndroid Build Coastguard Worker                        bags.append(embeddings.narrow(0, offset, length).sum(0))
978*da0073e9SAndroid Build Coastguard Worker                    elif mode == "mean":
979*da0073e9SAndroid Build Coastguard Worker                        bags.append(
980*da0073e9SAndroid Build Coastguard Worker                            embeddings.narrow(0, offset, length).sum(0).div(length)
981*da0073e9SAndroid Build Coastguard Worker                        )
982*da0073e9SAndroid Build Coastguard Worker                    else:
983*da0073e9SAndroid Build Coastguard Worker                        assert mode == "max"
984*da0073e9SAndroid Build Coastguard Worker                        bags.append(embeddings.narrow(0, offset, length).max(0)[0])
985*da0073e9SAndroid Build Coastguard Worker        return torch.stack(bags)
986*da0073e9SAndroid Build Coastguard Worker
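    # A worked sketch of the reference above (hypothetical values): with
    # input=[3, 1, 4], offsets=[0, 1], mode="sum", and per-sample weights
    # [a, b, c], the reference produces two bags:
    #
    #   bag 0 = a * weight[3]
    #   bag 1 = b * weight[1] + c * weight[4]
    #
    # An empty slice (two equal consecutive offsets) contributes a zero row.
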
987*da0073e9SAndroid Build Coastguard Worker    @skipMeta
988*da0073e9SAndroid Build Coastguard Worker    @dtypes(
989*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
990*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
991*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
992*da0073e9SAndroid Build Coastguard Worker            (torch.half, torch.bfloat16, torch.float, torch.double),
993*da0073e9SAndroid Build Coastguard Worker        )
994*da0073e9SAndroid Build Coastguard Worker    )
995*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
996*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
997*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
998*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
999*da0073e9SAndroid Build Coastguard Worker            (torch.float, torch.double, torch.half),
1000*da0073e9SAndroid Build Coastguard Worker        )
1001*da0073e9SAndroid Build Coastguard Worker    )
1002*da0073e9SAndroid Build Coastguard Worker    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
1003*da0073e9SAndroid Build Coastguard Worker        # Test an empty input with per_sample_weights, including the backward
1004*da0073e9SAndroid Build Coastguard Worker        # pass. There was a CUDA invalid-configuration bug (more context in #46572)
1005*da0073e9SAndroid Build Coastguard Worker        def test_per_sample_weights(mode, trainable_scale):
1006*da0073e9SAndroid Build Coastguard Worker            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
1007*da0073e9SAndroid Build Coastguard Worker            es.weight.data.copy_(
1008*da0073e9SAndroid Build Coastguard Worker                torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2])
1009*da0073e9SAndroid Build Coastguard Worker            )
1010*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor([], device=device, dtype=dtypes[0])
1011*da0073e9SAndroid Build Coastguard Worker            offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[1])
1012*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = torch.randn_like(
1013*da0073e9SAndroid Build Coastguard Worker                input, dtype=dtypes[2]
1014*da0073e9SAndroid Build Coastguard Worker            ).requires_grad_(trainable_scale)
1015*da0073e9SAndroid Build Coastguard Worker            ref_per_sample_weights = per_sample_weights.detach().requires_grad_(
1016*da0073e9SAndroid Build Coastguard Worker                trainable_scale
1017*da0073e9SAndroid Build Coastguard Worker            )
1018*da0073e9SAndroid Build Coastguard Worker            reference_weights = es.weight.detach().requires_grad_()
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker            expected = self._embedding_bag_reference_impl(
1021*da0073e9SAndroid Build Coastguard Worker                input, reference_weights, offsets, mode, ref_per_sample_weights
1022*da0073e9SAndroid Build Coastguard Worker            )
1023*da0073e9SAndroid Build Coastguard Worker            result = es(input, offsets, per_sample_weights)
1024*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1025*da0073e9SAndroid Build Coastguard Worker                result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0
1026*da0073e9SAndroid Build Coastguard Worker            )
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker            grad = torch.randn_like(expected)
1029*da0073e9SAndroid Build Coastguard Worker            result.backward(grad)
1030*da0073e9SAndroid Build Coastguard Worker            # the reference impl has no grad_fn for an empty input, but the
1031*da0073e9SAndroid Build Coastguard Worker            # gradient should simply be a zero tensor
1032*da0073e9SAndroid Build Coastguard Worker            ref_weights_grad = torch.zeros_like(es.weight)
1033*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1034*da0073e9SAndroid Build Coastguard Worker                es.weight.grad,
1035*da0073e9SAndroid Build Coastguard Worker                ref_weights_grad,
1036*da0073e9SAndroid Build Coastguard Worker                atol=dtype2prec_DONTUSE[dtypes[2]],
1037*da0073e9SAndroid Build Coastguard Worker                rtol=0,
1038*da0073e9SAndroid Build Coastguard Worker            )
1039*da0073e9SAndroid Build Coastguard Worker            if trainable_scale:
1040*da0073e9SAndroid Build Coastguard Worker                ref_per_sample_weights_grad = torch.zeros_like(per_sample_weights)
1041*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1042*da0073e9SAndroid Build Coastguard Worker                    per_sample_weights.grad,
1043*da0073e9SAndroid Build Coastguard Worker                    ref_per_sample_weights_grad,
1044*da0073e9SAndroid Build Coastguard Worker                    atol=dtype2prec_DONTUSE[dtypes[2]],
1045*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
1046*da0073e9SAndroid Build Coastguard Worker                )
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker        modes = ("sum",)
1049*da0073e9SAndroid Build Coastguard Worker        trainable_scale = (True, False)
1050*da0073e9SAndroid Build Coastguard Worker        for mode, trainable in itertools.product(modes, trainable_scale):
1051*da0073e9SAndroid Build Coastguard Worker            test_per_sample_weights(mode, trainable)
1052*da0073e9SAndroid Build Coastguard Worker
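    # Sketch of the empty-bag backward contract above (illustrative): even when
    # every bag is empty, backward runs and leaves an all-zero weight gradient:
    #
    #   es = nn.EmbeddingBag(5, 2, mode="sum")
    #   out = es(torch.tensor([], dtype=torch.long), torch.tensor([0, 0]))
    #   out.sum().backward()
    #   # es.weight.grad is defined and all zeros
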
1053*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1054*da0073e9SAndroid Build Coastguard Worker    @dtypes(
1055*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
1056*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1057*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1058*da0073e9SAndroid Build Coastguard Worker            (torch.float, torch.double, torch.half, torch.bfloat16),
1059*da0073e9SAndroid Build Coastguard Worker        )
1060*da0073e9SAndroid Build Coastguard Worker    )
1061*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
1062*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
1063*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1064*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1065*da0073e9SAndroid Build Coastguard Worker            (torch.float, torch.double, torch.half),
1066*da0073e9SAndroid Build Coastguard Worker        )
1067*da0073e9SAndroid Build Coastguard Worker    )
1068*da0073e9SAndroid Build Coastguard Worker    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
1069*da0073e9SAndroid Build Coastguard Worker        def test_per_sample_weights(mode, trainable_scale):
1070*da0073e9SAndroid Build Coastguard Worker            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
1071*da0073e9SAndroid Build Coastguard Worker            es.weight.data.copy_(
1072*da0073e9SAndroid Build Coastguard Worker                torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2])
1073*da0073e9SAndroid Build Coastguard Worker            )
1074*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
1075*da0073e9SAndroid Build Coastguard Worker            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1])
1076*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = torch.randn_like(
1077*da0073e9SAndroid Build Coastguard Worker                input, dtype=dtypes[2]
1078*da0073e9SAndroid Build Coastguard Worker            ).requires_grad_(trainable_scale)
1079*da0073e9SAndroid Build Coastguard Worker            ref_per_sample_weights = per_sample_weights.detach().requires_grad_(
1080*da0073e9SAndroid Build Coastguard Worker                trainable_scale
1081*da0073e9SAndroid Build Coastguard Worker            )
1082*da0073e9SAndroid Build Coastguard Worker            reference_weights = es.weight.detach().requires_grad_()
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker            expected = self._embedding_bag_reference_impl(
1085*da0073e9SAndroid Build Coastguard Worker                input, reference_weights, offsets, mode, ref_per_sample_weights
1086*da0073e9SAndroid Build Coastguard Worker            )
1087*da0073e9SAndroid Build Coastguard Worker            result = es(input, offsets, per_sample_weights)
1088*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1089*da0073e9SAndroid Build Coastguard Worker                result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0
1090*da0073e9SAndroid Build Coastguard Worker            )
1091*da0073e9SAndroid Build Coastguard Worker
1092*da0073e9SAndroid Build Coastguard Worker            grad = torch.randn_like(expected).to(dtype=dtypes[2], device=device)
1093*da0073e9SAndroid Build Coastguard Worker            result.backward(grad)
1094*da0073e9SAndroid Build Coastguard Worker            expected.backward(grad)
1095*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1096*da0073e9SAndroid Build Coastguard Worker                es.weight.grad,
1097*da0073e9SAndroid Build Coastguard Worker                reference_weights.grad,
1098*da0073e9SAndroid Build Coastguard Worker                atol=dtype2prec_DONTUSE[dtypes[2]],
1099*da0073e9SAndroid Build Coastguard Worker                rtol=0,
1100*da0073e9SAndroid Build Coastguard Worker            )
1101*da0073e9SAndroid Build Coastguard Worker            if trainable_scale:
1102*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1103*da0073e9SAndroid Build Coastguard Worker                    per_sample_weights.grad,
1104*da0073e9SAndroid Build Coastguard Worker                    ref_per_sample_weights.grad,
1105*da0073e9SAndroid Build Coastguard Worker                    atol=dtype2prec_DONTUSE[dtypes[2]],
1106*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
1107*da0073e9SAndroid Build Coastguard Worker                )
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker        modes = ("sum",)
1110*da0073e9SAndroid Build Coastguard Worker        trainable_scale = (True, False)
1111*da0073e9SAndroid Build Coastguard Worker        for mode, trainable in itertools.product(modes, trainable_scale):
1112*da0073e9SAndroid Build Coastguard Worker            test_per_sample_weights(mode, trainable)
1113*da0073e9SAndroid Build Coastguard Worker
1114*da0073e9SAndroid Build Coastguard Worker    @skipMeta
1115*da0073e9SAndroid Build Coastguard Worker    @dtypes(
1116*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
1117*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1118*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1119*da0073e9SAndroid Build Coastguard Worker            (torch.float, torch.double, torch.half, torch.bfloat16),
1120*da0073e9SAndroid Build Coastguard Worker        )
1121*da0073e9SAndroid Build Coastguard Worker    )
1122*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
1123*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
1124*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1125*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long),
1126*da0073e9SAndroid Build Coastguard Worker            (torch.float, torch.double, torch.half),
1127*da0073e9SAndroid Build Coastguard Worker        )
1128*da0073e9SAndroid Build Coastguard Worker    )
1129*da0073e9SAndroid Build Coastguard Worker    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
1130*da0073e9SAndroid Build Coastguard Worker        def test_per_sample_weights_new_offsets(
1131*da0073e9SAndroid Build Coastguard Worker            mode, trainable_scale, include_last_offset, has_weight=True
1132*da0073e9SAndroid Build Coastguard Worker        ):
1133*da0073e9SAndroid Build Coastguard Worker            es = nn.EmbeddingBag(
1134*da0073e9SAndroid Build Coastguard Worker                5, 2, mode=mode, include_last_offset=include_last_offset
1135*da0073e9SAndroid Build Coastguard Worker            ).to(dtype=dtypes[2], device=device)
1136*da0073e9SAndroid Build Coastguard Worker            es.weight.data.copy_(
1137*da0073e9SAndroid Build Coastguard Worker                torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2])
1138*da0073e9SAndroid Build Coastguard Worker            )
1139*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
1140*da0073e9SAndroid Build Coastguard Worker            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1])
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker            if include_last_offset:
1143*da0073e9SAndroid Build Coastguard Worker                offsets = torch.cat(
1144*da0073e9SAndroid Build Coastguard Worker                    (
1145*da0073e9SAndroid Build Coastguard Worker                        offsets,
1146*da0073e9SAndroid Build Coastguard Worker                        torch.tensor([input.size(0)], device=device, dtype=dtypes[1]),
1147*da0073e9SAndroid Build Coastguard Worker                    ),
1148*da0073e9SAndroid Build Coastguard Worker                    0,
1149*da0073e9SAndroid Build Coastguard Worker                )
1150*da0073e9SAndroid Build Coastguard Worker
1151*da0073e9SAndroid Build Coastguard Worker            if has_weight:
1152*da0073e9SAndroid Build Coastguard Worker                per_sample_weights = torch.randn_like(
1153*da0073e9SAndroid Build Coastguard Worker                    input, device=device, dtype=dtypes[2]
1154*da0073e9SAndroid Build Coastguard Worker                ).requires_grad_(trainable_scale)
1155*da0073e9SAndroid Build Coastguard Worker                ref_per_sample_weights = per_sample_weights.detach().requires_grad_(
1156*da0073e9SAndroid Build Coastguard Worker                    trainable_scale
1157*da0073e9SAndroid Build Coastguard Worker                )
1158*da0073e9SAndroid Build Coastguard Worker            else:
1159*da0073e9SAndroid Build Coastguard Worker                per_sample_weights = None
1160*da0073e9SAndroid Build Coastguard Worker                ref_per_sample_weights = None
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker            reference_weights = es.weight.detach().requires_grad_()
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker            expected = self._embedding_bag_reference_impl(
1165*da0073e9SAndroid Build Coastguard Worker                input,
1166*da0073e9SAndroid Build Coastguard Worker                reference_weights,
1167*da0073e9SAndroid Build Coastguard Worker                offsets,
1168*da0073e9SAndroid Build Coastguard Worker                mode,
1169*da0073e9SAndroid Build Coastguard Worker                ref_per_sample_weights,
1170*da0073e9SAndroid Build Coastguard Worker                include_last_offset,
1171*da0073e9SAndroid Build Coastguard Worker            )
1172*da0073e9SAndroid Build Coastguard Worker            result = es(input, offsets, per_sample_weights)
1173*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1174*da0073e9SAndroid Build Coastguard Worker                result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0
1175*da0073e9SAndroid Build Coastguard Worker            )
1176*da0073e9SAndroid Build Coastguard Worker
1177*da0073e9SAndroid Build Coastguard Worker            grad = torch.randn_like(expected)
1178*da0073e9SAndroid Build Coastguard Worker            result.backward(grad)
1179*da0073e9SAndroid Build Coastguard Worker            expected.backward(grad)
1180*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1181*da0073e9SAndroid Build Coastguard Worker                es.weight.grad,
1182*da0073e9SAndroid Build Coastguard Worker                reference_weights.grad,
1183*da0073e9SAndroid Build Coastguard Worker                atol=dtype2prec_DONTUSE[dtypes[2]],
1184*da0073e9SAndroid Build Coastguard Worker                rtol=0,
1185*da0073e9SAndroid Build Coastguard Worker            )
1186*da0073e9SAndroid Build Coastguard Worker            if has_weight and trainable_scale:
1187*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1188*da0073e9SAndroid Build Coastguard Worker                    per_sample_weights.grad,
1189*da0073e9SAndroid Build Coastguard Worker                    ref_per_sample_weights.grad,
1190*da0073e9SAndroid Build Coastguard Worker                    atol=dtype2prec_DONTUSE[dtypes[2]],
1191*da0073e9SAndroid Build Coastguard Worker                    rtol=0,
1192*da0073e9SAndroid Build Coastguard Worker                )
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker        trainable_scale = (True, False)
1195*da0073e9SAndroid Build Coastguard Worker        include_last_offset_list = (True, False)
1196*da0073e9SAndroid Build Coastguard Worker        modes = (("sum", False), ("sum", True), ("max", False), ("mean", False))
1197*da0073e9SAndroid Build Coastguard Worker        for (mode, has_weight), trainable, include_last_offset in itertools.product(
1198*da0073e9SAndroid Build Coastguard Worker            modes, trainable_scale, include_last_offset_list
1199*da0073e9SAndroid Build Coastguard Worker        ):
1200*da0073e9SAndroid Build Coastguard Worker            test_per_sample_weights_new_offsets(
1201*da0073e9SAndroid Build Coastguard Worker                mode, trainable, include_last_offset, has_weight
1202*da0073e9SAndroid Build Coastguard Worker            )
1203*da0073e9SAndroid Build Coastguard Worker
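    # Sketch of the include_last_offset semantics exercised above (illustrative
    # values): with include_last_offset=True, offsets carries num_bags + 1
    # entries and the final entry marks the end of the last bag:
    #
    #   w = torch.randn(5, 2)
    #   inp = torch.tensor([0, 1, 2, 3])
    #   off = torch.tensor([0, 2, 4])
    #   F.embedding_bag(inp, w, off, include_last_offset=True)  # 2 bags
    #   F.embedding_bag(inp, w, off[:-1])                       # also 2 bags
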
1204*da0073e9SAndroid Build Coastguard Worker    def _test_EmbeddingBag_vs_Embedding(
1205*da0073e9SAndroid Build Coastguard Worker        self,
1206*da0073e9SAndroid Build Coastguard Worker        N,
1207*da0073e9SAndroid Build Coastguard Worker        D,
1208*da0073e9SAndroid Build Coastguard Worker        B,
1209*da0073e9SAndroid Build Coastguard Worker        L,
1210*da0073e9SAndroid Build Coastguard Worker        max_norm=None,
1211*da0073e9SAndroid Build Coastguard Worker        mode="mean",
1212*da0073e9SAndroid Build Coastguard Worker        device="cpu",
1213*da0073e9SAndroid Build Coastguard Worker        wdtype=torch.float,
1214*da0073e9SAndroid Build Coastguard Worker        dtype=torch.long,
1215*da0073e9SAndroid Build Coastguard Worker        test_per_sample_weights=False,
1216*da0073e9SAndroid Build Coastguard Worker        trainable_per_sample_weights=False,
1217*da0073e9SAndroid Build Coastguard Worker        sparse=False,
1218*da0073e9SAndroid Build Coastguard Worker        test_backward=True,
1219*da0073e9SAndroid Build Coastguard Worker        backward_prec=None,
1220*da0073e9SAndroid Build Coastguard Worker    ):
1221*da0073e9SAndroid Build Coastguard Worker        es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(
1222*da0073e9SAndroid Build Coastguard Worker            device, wdtype
1223*da0073e9SAndroid Build Coastguard Worker        )
1224*da0073e9SAndroid Build Coastguard Worker        e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype)
1225*da0073e9SAndroid Build Coastguard Worker        e.weight.data.copy_(es.weight)
1226*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(N, (B, L), device=device, dtype=dtype)
1227*da0073e9SAndroid Build Coastguard Worker        offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L)
1228*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.rand(B, D, device=device, dtype=wdtype)
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker        if test_per_sample_weights:
1231*da0073e9SAndroid Build Coastguard Worker            # To prevent large gradients, weights should sum to 1 for each bag
1232*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = torch.randn(B, L, device=device, dtype=wdtype).softmax(
1233*da0073e9SAndroid Build Coastguard Worker                dim=-1
1234*da0073e9SAndroid Build Coastguard Worker            )
1235*da0073e9SAndroid Build Coastguard Worker            per_sample_weights_reference = per_sample_weights.clone().requires_grad_(
1236*da0073e9SAndroid Build Coastguard Worker                trainable_per_sample_weights
1237*da0073e9SAndroid Build Coastguard Worker            )
1238*da0073e9SAndroid Build Coastguard Worker            per_sample_weights.requires_grad_(trainable_per_sample_weights)
1239*da0073e9SAndroid Build Coastguard Worker            output = es(input.view(-1), offsets, per_sample_weights.view(-1))
1240*da0073e9SAndroid Build Coastguard Worker        else:
1241*da0073e9SAndroid Build Coastguard Worker            output = es(input.view(-1), offsets)
1242*da0073e9SAndroid Build Coastguard Worker            per_sample_weights = None
1243*da0073e9SAndroid Build Coastguard Worker            per_sample_weights_reference = None
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        if mode == "sum":
1246*da0073e9SAndroid Build Coastguard Worker            if test_per_sample_weights:
1247*da0073e9SAndroid Build Coastguard Worker                ref_output = (
1248*da0073e9SAndroid Build Coastguard Worker                    e(input) * per_sample_weights_reference.unsqueeze(-1)
1249*da0073e9SAndroid Build Coastguard Worker                ).sum(1)
1250*da0073e9SAndroid Build Coastguard Worker            else:
1251*da0073e9SAndroid Build Coastguard Worker                ref_output = e(input).sum(1)
1252*da0073e9SAndroid Build Coastguard Worker        elif mode == "mean":
1253*da0073e9SAndroid Build Coastguard Worker            assert not test_per_sample_weights
1254*da0073e9SAndroid Build Coastguard Worker            ref_output = e(input).mean(1)
1255*da0073e9SAndroid Build Coastguard Worker        elif mode == "max":
1256*da0073e9SAndroid Build Coastguard Worker            assert not test_per_sample_weights
1257*da0073e9SAndroid Build Coastguard Worker            ref_output = e(input).max(1)[0]
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0)
1260*da0073e9SAndroid Build Coastguard Worker
1261*da0073e9SAndroid Build Coastguard Worker        if not test_backward:
1262*da0073e9SAndroid Build Coastguard Worker            return
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker        output.backward(grad_output)
1265*da0073e9SAndroid Build Coastguard Worker        ref_output.backward(grad_output)
1266*da0073e9SAndroid Build Coastguard Worker        es_weight_grad = es.weight.grad
1267*da0073e9SAndroid Build Coastguard Worker        if sparse:
1268*da0073e9SAndroid Build Coastguard Worker            es_weight_grad = es.weight.grad.to_dense()
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker        # We have more floating point error here because we are dealing with larger numbers
1271*da0073e9SAndroid Build Coastguard Worker        if backward_prec is None:
1272*da0073e9SAndroid Build Coastguard Worker            needed_prec = dtype2prec_DONTUSE[wdtype] * 5
1273*da0073e9SAndroid Build Coastguard Worker            rtol = 0.02 if wdtype == torch.half else 0
1274*da0073e9SAndroid Build Coastguard Worker        else:
1275*da0073e9SAndroid Build Coastguard Worker            needed_prec = backward_prec
1276*da0073e9SAndroid Build Coastguard Worker            rtol = 0
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(es_weight_grad, e.weight.grad, atol=needed_prec, rtol=rtol)
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker        if test_per_sample_weights and trainable_per_sample_weights:
1281*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1282*da0073e9SAndroid Build Coastguard Worker                per_sample_weights.grad,
1283*da0073e9SAndroid Build Coastguard Worker                per_sample_weights_reference.grad,
1284*da0073e9SAndroid Build Coastguard Worker                atol=dtype2prec_DONTUSE[wdtype],
1285*da0073e9SAndroid Build Coastguard Worker                rtol=0,
1286*da0073e9SAndroid Build Coastguard Worker            )
1287*da0073e9SAndroid Build Coastguard Worker
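    # Sketch of the equivalence the helper above checks (illustrative, using the
    # helper's N, D, B, L names): with shared weights and offsets [0, L, 2L, ...],
    # EmbeddingBag over the flattened batch matches Embedding plus a reduction:
    #
    #   bag = nn.EmbeddingBag(N, D, mode="sum")
    #   emb = nn.Embedding(N, D)
    #   emb.weight.data.copy_(bag.weight)
    #   offsets = torch.arange(0, B) * L
    #   # bag(input.view(-1), offsets) == emb(input).sum(1)
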
1288*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
1289*da0073e9SAndroid Build Coastguard Worker        *itertools.product(
1290*da0073e9SAndroid Build Coastguard Worker            (torch.int, torch.long), (torch.half, torch.float, torch.double)
1291*da0073e9SAndroid Build Coastguard Worker        )
1292*da0073e9SAndroid Build Coastguard Worker    )
1293*da0073e9SAndroid Build Coastguard Worker    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
1294*da0073e9SAndroid Build Coastguard Worker    def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
1295*da0073e9SAndroid Build Coastguard Worker        def run_tests(mode, sparse, trainable_per_sample_weights):
1296*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(
1297*da0073e9SAndroid Build Coastguard Worker                test_per_sample_weights=True,
1298*da0073e9SAndroid Build Coastguard Worker                device=device,
1299*da0073e9SAndroid Build Coastguard Worker                mode=mode,
1300*da0073e9SAndroid Build Coastguard Worker                wdtype=dtypes[1],
1301*da0073e9SAndroid Build Coastguard Worker                dtype=dtypes[0],
1302*da0073e9SAndroid Build Coastguard Worker                sparse=sparse,
1303*da0073e9SAndroid Build Coastguard Worker                trainable_per_sample_weights=trainable_per_sample_weights,
1304*da0073e9SAndroid Build Coastguard Worker            )
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker            # Simple case
1307*da0073e9SAndroid Build Coastguard Worker            self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs)
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker            # B * L > 1000
1310*da0073e9SAndroid Build Coastguard Worker            self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs)
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker            # Large num_embeddings
1313*da0073e9SAndroid Build Coastguard Worker            self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs)
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker            # Large embedding_dim
1316*da0073e9SAndroid Build Coastguard Worker            self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
1317*da0073e9SAndroid Build Coastguard Worker
1318*da0073e9SAndroid Build Coastguard Worker        modes = ("sum",)
1319*da0073e9SAndroid Build Coastguard Worker        sparsity = (True, False)
1320*da0073e9SAndroid Build Coastguard Worker        trainable_scale = (True, False)
1321*da0073e9SAndroid Build Coastguard Worker        for mode, sparse, trainable_per_sample_weights in itertools.product(
1322*da0073e9SAndroid Build Coastguard Worker            modes, sparsity, trainable_scale
1323*da0073e9SAndroid Build Coastguard Worker        ):
1324*da0073e9SAndroid Build Coastguard Worker            run_tests(mode, sparse, trainable_per_sample_weights)
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker        # Test CUDA Dense on half precision
1327*da0073e9SAndroid Build Coastguard Worker        if device == "cuda":
1328*da0073e9SAndroid Build Coastguard Worker            modes = ("sum",)
1329*da0073e9SAndroid Build Coastguard Worker            sparsity = (False,)
1330*da0073e9SAndroid Build Coastguard Worker            trainable_scale = (True, False)
1331*da0073e9SAndroid Build Coastguard Worker            for mode, sparse, trainable_per_sample_weights in itertools.product(
1332*da0073e9SAndroid Build Coastguard Worker                modes, sparsity, trainable_scale
1333*da0073e9SAndroid Build Coastguard Worker            ):
1334*da0073e9SAndroid Build Coastguard Worker                run_tests(mode, sparse, trainable_per_sample_weights)
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker    def _test_EmbeddingBag(
1337*da0073e9SAndroid Build Coastguard Worker        self,
1338*da0073e9SAndroid Build Coastguard Worker        device,
1339*da0073e9SAndroid Build Coastguard Worker        mode,
1340*da0073e9SAndroid Build Coastguard Worker        sparse,
1341*da0073e9SAndroid Build Coastguard Worker        wdtype=torch.double,
1342*da0073e9SAndroid Build Coastguard Worker        dtype=torch.long,
1343*da0073e9SAndroid Build Coastguard Worker        odtype=torch.long,
1344*da0073e9SAndroid Build Coastguard Worker        test_backward=True,
1345*da0073e9SAndroid Build Coastguard Worker    ):
1346*da0073e9SAndroid Build Coastguard Worker        # check a known test example
1347*da0073e9SAndroid Build Coastguard Worker        es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype)
1348*da0073e9SAndroid Build Coastguard Worker        es.weight.data.copy_(
1349*da0073e9SAndroid Build Coastguard Worker            torch.arange(1, 11, device=device).view_as(es.weight).to(wdtype)
1350*da0073e9SAndroid Build Coastguard Worker        )
1351*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype)
1352*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=odtype)

        grad_output = torch.tensor([1, 2, 3, 4], device=device, dtype=wdtype).view(2, 2)
        grad_output_with_empty = torch.tensor(
            [99, 99, 1, 2, 99, 99, 3, 4, 99, 99], device=device, dtype=wdtype
        ).view(5, 2)
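        # The 99 entries line up with the empty output rows (bags 0, 2, 4);
        # they must not leak into the weight gradient.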

        if mode in ("sum", "mean"):
            denominator = 1 if mode == "sum" else 3
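            # With weight rows arange(1, 11).view(5, 2), bag 1 sums rows
            # 3, 1, 1: [7, 8] + [3, 4] + [3, 4] = [13, 16], and bag 3 sums
            # rows 1, 4, 0: [3, 4] + [9, 10] + [1, 2] = [13, 16]; "mean"
            # divides by the bag length of 3.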
            expected_output = (
                torch.tensor([[13, 16], [13, 16]], device=device, dtype=wdtype)
                / denominator
            )

            expected_output_with_empty = (
                torch.tensor(
                    [[0, 0], [13, 16], [0, 0], [13, 16], [0, 0]],
                    device=device,
                    dtype=wdtype,
                )
                / denominator
            )

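            # Gradient bookkeeping for "sum": bag 1's grad [1, 2] flows to
            # rows 3 and (twice) 1; bag 3's grad [3, 4] flows to rows 1, 4,
            # and 0, so row 1 accumulates [5, 8].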
            expected_grad_weight = (
                torch.tensor(
                    [[3, 4], [5, 8], [0, 0], [1, 2], [3, 4]],
                    device=device,
                    dtype=wdtype,
                )
                / denominator
            )
        elif mode == "max":
            expected_output = torch.tensor(
                [[7, 8], [9, 10]], device=device, dtype=wdtype
            )

            expected_output_with_empty = torch.tensor(
                [[0, 0], [7, 8], [0, 0], [9, 10], [0, 0]], device=device, dtype=wdtype
            )

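            # In "max" mode only the argmax rows receive gradient: rows 3 and
            # 4 produced the per-bag maxima, so every other row stays zero.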
            expected_grad_weight = torch.tensor(
                [[0, 0], [0, 0], [0, 0], [1, 2], [3, 4]], device=device, dtype=wdtype
            )
        output = es(input, offsets)
        output.backward(grad_output_with_empty)

        es_weight_grad = es.weight.grad
        if sparse:
            es_weight_grad = es.weight.grad.to_dense()
        self.assertEqual(output, expected_output_with_empty)
        self.assertEqual(
            es_weight_grad,
            expected_grad_weight,
            atol=dtype2prec_DONTUSE[wdtype],
            rtol=0,
        )

        # Check the same example reshaped to 2-D (2 x 3); without offsets,
        # each row is treated as one bag.
        input = input.view(2, -1)
        es.zero_grad()
        output = es(input)
        output.backward(grad_output)

        es_weight_grad = es.weight.grad
        if sparse:
            es_weight_grad = es.weight.grad.to_dense()
        self.assertEqual(output, expected_output)
        self.assertEqual(
            es_weight_grad,
            expected_grad_weight,
            atol=dtype2prec_DONTUSE[wdtype],
            rtol=0,
        )

        # Test that all-empty bags produce zero output and zero gradient.
        es.zero_grad()
        inputs = torch.tensor([], dtype=dtype, device=device)
        offsets = torch.tensor([0, 0, 0, 0], dtype=odtype, device=device)
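        # Four offsets of 0 over an empty input give four empty bags; each
        # yields a zero vector, so every weight-gradient entry must be zero.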
        es(inputs, offsets).sum().backward()
        dense_grad = es.weight.grad
        if dense_grad.is_sparse:
            dense_grad = dense_grad.to_dense()
        self.assertEqual(dense_grad, torch.zeros_like(es.weight))

        # Now compare EmbeddingBag against Embedding + sum/mean for a constant bag length.
        N, D, B, L = (
            random.randint(1, 100),
            random.randint(1, 100),
            random.randint(1, 50),
            random.randint(1, 50),
        )
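        # N = num_embeddings, D = embedding_dim, B = number of bags,
        # L = the (constant) bag length.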
        kwargs = dict(
            mode=mode,
            sparse=sparse,
            device=device,
            wdtype=wdtype,
            dtype=dtype,
            test_backward=test_backward,
        )
        self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
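        # Sweep all 16 combinations of N, D, B, L in {1, 2}, with and without
        # max_norm renormalization, to exercise the degenerate small sizes.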
        for max_norm in (None, 3):
            for p in itertools.product([1, 2], repeat=4):
                self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs)

        # Check that illegal input combinations raise errors: 2-D input may
        # not be combined with offsets, and 1-D input requires them.
        es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
        input = torch.ones(3, 4, dtype=dtype)
        offset = torch.arange(0, 3, dtype=odtype)
        torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input, offset))
        torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input.view(-1)))
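        # On CPU, malformed offsets are rejected: the first offset must be 0
        # and no offset may exceed the number of indices.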
        offset[0] = 1
        if self.device_type == "cpu":
            torch._dynamo.disable(self.assertRaises)(
                RuntimeError, lambda: es(input.view(-1), offset)
            )
            offset[0] = 0
            offset[-1] = 100
            torch._dynamo.disable(self.assertRaises)(
                RuntimeError, lambda: es(input.view(-1), offset)
            )

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_embedding_bag_device(self, device, dtypes):
        if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu":
            self.skipTest("bfloat16 not supported on Jetson CPU")
        with set_default_dtype(torch.double):
            self._test_EmbeddingBag(
                device,
                "sum",
                False,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
            )
            self._test_EmbeddingBag(
                device,
                "mean",
                False,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
            )
            self._test_EmbeddingBag(
                device,
                "max",
                False,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
            )

            test_backward = False
            if self.device_type == "cuda":
                # See the 'todo' in test_embedding_bag: skip the half-precision
                # backward on CUDA.
                test_backward = dtypes[2] is not torch.float16
            elif self.device_type == "cpu":
                # TODO: figure out why precision on sparse embeddings isn't the
                # same as for dense.
                test_backward = (
                    dtypes[2] is not torch.float and dtypes[2] is not torch.float16
                )

            self._test_EmbeddingBag(
                device,
                "sum",
                True,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=test_backward,
            )
            self._test_EmbeddingBag(
                device,
                "mean",
                True,
                wdtype=dtypes[2],
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=test_backward,
            )

    @skipMeta
    @dtypes(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half, torch.bfloat16),
        )
    )
    @dtypesIfCUDA(
        *itertools.product(
            (torch.int, torch.long),
            (torch.int, torch.long),
            (torch.float, torch.double, torch.half),
        )
    )
    def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
        weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)

        weight_tensor_non_contig = weight_tensor[
            :, :3
        ]  # This is non-contiguous strided.
        weight_tensor_contig = (
            weight_tensor_non_contig.clone().contiguous()
        )  # Contiguous-strided.

        index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device)
        offsets = torch.tensor([0, 2], dtype=dtypes[1], device=device)
        for mode in ["sum", "mean", "max"]:
            output_non_contig = F.embedding_bag(
                input=index,
                weight=weight_tensor_non_contig,
                offsets=offsets,
                mode=mode,
            )
            output_contig = F.embedding_bag(
                input=index,
                weight=weight_tensor_contig,
                offsets=offsets,
                mode=mode,
            )
            # Compare inside the loop so every mode is actually checked, not
            # just the last one.
            self.assertEqual(output_non_contig, output_contig)

    @onlyNativeDeviceTypes  # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_bfloat16(self, device, dtypes):
        with set_default_dtype(torch.double):
            self._test_EmbeddingBag(
                device,
                "sum",
                True,
                wdtype=torch.bfloat16,
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=True,
            )
            self._test_EmbeddingBag(
                device,
                "mean",
                True,
                wdtype=torch.bfloat16,
                dtype=dtypes[0],
                odtype=dtypes[1],
                test_backward=True,
            )

    @onlyNativeDeviceTypes  # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_half(self, device, dtypes):
        self._test_EmbeddingBag(
            device,
            "sum",
            True,
            wdtype=torch.float16,
            dtype=dtypes[0],
            odtype=dtypes[1],
            test_backward=True,
        )
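
    # A minimal, hedged sketch (not part of the original suite; the method
    # name and sizes are illustrative) of the core equivalence exercised
    # above: EmbeddingBag(mode="sum") over 2-D input must match an Embedding
    # lookup followed by a per-bag sum over dim 1.
    def test_embedding_bag_sum_matches_embedding_sum_sketch(self, device):
        es = nn.EmbeddingBag(4, 3, mode="sum").to(device)
        e = nn.Embedding(4, 3).to(device)
        # Share one weight matrix so the two modules are directly comparable.
        e.weight.data.copy_(es.weight.data)
        # For 2-D input of shape (B, L), each row is one fixed-length bag.
        input = torch.tensor([[0, 1], [2, 3]], device=device)
        self.assertEqual(es(input), e(input).sum(1))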


instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals())
instantiate_parametrized_tests(TestEmbeddingNN)

if __name__ == "__main__":
    run_tests()