1# Owner(s): ["module: nn"] 2import itertools 3import random 4import unittest 5from itertools import product 6 7import torch 8import torch.nn as nn 9import torch.nn.functional as F 10from torch.testing._internal.common_cuda import TEST_CUDA 11from torch.testing._internal.common_device_type import ( 12 dtypes, 13 dtypesIfCUDA, 14 instantiate_device_type_tests, 15 largeTensorTest, 16 onlyCUDA, 17 onlyNativeDeviceTypes, 18 skipCUDAIf, 19 skipMeta, 20 TEST_WITH_ROCM, 21) 22 23from torch.testing._internal.common_nn import NNTestCase 24from torch.testing._internal.common_utils import ( 25 _assertGradAndGradgradChecks, 26 dtype2prec_DONTUSE, 27 instantiate_parametrized_tests, 28 IS_JETSON, 29 parametrize as parametrize_test, 30 run_tests, 31 set_default_dtype, 32 skipIfTorchDynamo, 33) 34 35 36class TestEmbeddingNN(NNTestCase): 37 _do_cuda_memory_leak_check = True 38 _do_cuda_non_default_stream = True 39 40 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 41 def test_embedding_max_norm_unsorted_repeating_indices(self): 42 def create_embedding(device): 43 # Seed RNG so we get the same Embedding each time 44 torch.manual_seed(0) 45 return torch.nn.Embedding( 46 num_embeddings=20, embedding_dim=64, max_norm=1.0 47 ).to(device) 48 49 ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000) 50 out_cpu = create_embedding("cpu")(ix) 51 52 ix = ix.to("cuda") 53 out = create_embedding("cuda")(ix) 54 self.assertEqual(out.cpu(), out_cpu) 55 56 def test_embedding_sparse_basic(self): 57 embedding = nn.Embedding(10, 20, sparse=True) 58 input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long) 59 embedding(input).sum().backward() 60 self.assertTrue(embedding.weight.grad.is_sparse) 61 self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) 62 63 def test_embedding_sparse_empty_tensor(self): 64 embedding = nn.Embedding(0, 0, sparse=True) 65 input = torch.tensor([], dtype=torch.int64) 66 embedding(input).sum().backward() 67 self.assertTrue(embedding.weight.grad.is_sparse) 68 self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) 69 70 embedding = nn.Embedding(10, 0, sparse=True) 71 input = torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]) 72 embedding(input).sum().backward() 73 self.assertTrue(embedding.weight.grad.is_sparse) 74 self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) 75 76 def test_move_sparse_half_embedding(self): 77 embedding = nn.Embedding(10, 3, sparse=True) 78 self.assertEqual(embedding.weight.device.type, "cpu") 79 self.assertEqual(embedding.weight.dtype, torch.get_default_dtype()) 80 embedding.to(torch.float16) 81 self.assertEqual(embedding.weight.dtype, torch.float16) 82 self.assertEqual(embedding.embedding_dim, 3) 83 self.assertEqual(embedding.num_embeddings, 10) 84 85 if torch.cuda.is_available(): 86 embedding.to("cuda") 87 self.assertEqual(embedding.weight.device.type, "cuda") 88 embedding.to("cpu") 89 self.assertEqual(embedding.weight.device.type, "cpu") 90 91 def test_embedding_max_norm(self): 92 embedding = nn.Embedding(22, 5, max_norm=1.0) 93 input = torch.tensor([2, 8, 8, 6], dtype=torch.long) 94 output = embedding(input) 95 self.assertEqual(output[1], output[2]) 96 self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) 97 98 @parametrize_test( 99 "dtype", 100 ( 101 torch.uint8, 102 torch.int8, 103 torch.int16, 104 torch.int32, 105 torch.int64, 106 torch.float, 107 torch.double, 108 ), 109 ) 110 def test_embedding_from_pretrained(self, dtype): 111 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 112 embedding = nn.Embedding.from_pretrained(a) 113 self.assertEqual(a, embedding.weight.data) 114 115 input = torch.LongTensor([0, 1]) 116 output = embedding(input) 117 self.assertEqual(a, output) 118 119 def test_embedding_bag_from_pretrained(self): 120 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 121 embedding = nn.EmbeddingBag.from_pretrained(a) 122 self.assertEqual(a, embedding.weight) 123 124 input = torch.tensor([0, 1], dtype=torch.long) 125 output = embedding(input, torch.arange(input.size(0))) 126 self.assertEqual(a, output) 127 128 def test_embedding_from_pretrained_padding_idx(self): 129 padding_idx = 2 130 padding_vec = torch.ones(3) * 7 131 embeddings = torch.rand(4, 3, requires_grad=True) 132 with torch.no_grad(): 133 embeddings[padding_idx] = padding_vec 134 embedding_nn = nn.Embedding.from_pretrained(embeddings, padding_idx=padding_idx) 135 self.assertEqual(embedding_nn.weight[padding_idx], padding_vec) 136 137 def test_embedding_bag_from_pretrained_padding_idx(self): 138 padding_idx = 2 139 embeddings = torch.rand(4, 3, requires_grad=True) 140 embedding_nn = nn.EmbeddingBag.from_pretrained( 141 embeddings, padding_idx=padding_idx 142 ) 143 self.assertEqual(embedding_nn.weight, embeddings) 144 145 def test_embedding_from_pretrained_options(self): 146 with set_default_dtype(torch.double): 147 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 148 opts = { 149 "max_norm": 2.0, 150 "norm_type": 0.5, 151 "scale_grad_by_freq": False, 152 "sparse": True, 153 } 154 embedding = nn.Embedding.from_pretrained(a, **opts) 155 input = torch.LongTensor([0, 1]) 156 output = embedding(input) 157 # test output and that weight matrix was renormalized 158 self.assertEqual(a, output) 159 self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()) 160 self.assertTrue( 161 output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all() 162 ) 163 164 def test_embedding_functional(self): 165 a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long) 166 embeddings = torch.rand(4, 3, requires_grad=True) 167 168 embed_old = torch.nn.Embedding(4, 3) 169 embed_old.weight.data = embeddings.data 170 # A silly test for eager, this test is useful for when we run under PYTORCH_TEST_WITH_DYNAMO=1 171 # as it ensures that setattr correctly works. 172 self.assertEqual(embed_old.weight.data, embeddings.data) 173 res_old = embed_old(a) 174 175 res_F = F.embedding(a, embeddings) 176 self.assertEqual(res_old, res_F) 177 178 embed_old = torch.nn.Embedding(4, 3) 179 embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) 180 res_old = embed_old(a) 181 res_F = F.embedding(a, embeddings, padding_idx=2) 182 183 self.assertEqual(res_old, res_F) 184 185 # https://github.com/pytorch/pytorch/issues/130806 186 @largeTensorTest("40GB", device="cuda") 187 def test_large_tensors(self): 188 input = torch.randint(low=0, high=16032, size=[131072], device="cuda") 189 w = torch.randn([16032, 16384], device="cuda") 190 out = torch.nn.functional.embedding(input, w) 191 self.assertEqual(out.dim(), 2) 192 self.assertEqual(out.numel(), 2147483648) 193 194 def test_embedding_bag_functional(self): 195 a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long) 196 embeddings = torch.rand(4, 3, requires_grad=True) 197 198 embed_old = torch.nn.EmbeddingBag(4, 3) 199 embed_old.weight = torch.nn.Parameter(embeddings) 200 res_old = embed_old(a) 201 202 res_F = F.embedding_bag(a, embeddings) 203 self.assertEqual(res_old, res_F) 204 205 embed_old = torch.nn.EmbeddingBag(4, 3) 206 embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) 207 res_old = embed_old(a) 208 res_F = F.embedding_bag(a, embeddings, padding_idx=2) 209 210 self.assertEqual(res_old, res_F) 211 212 # Make sure that error is thrown if padding_idx is out of bounds 213 def test_embedding_bag_padding_idx_error(self): 214 a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long) 215 num_embeddings = 4 216 num_features = 3 217 embeddings = torch.rand(num_embeddings, num_features, requires_grad=True) 218 219 functional_err_msg = r"padding_idx must be within the number of embeddings" 220 module_err_msg = r"padding_idx must be within num_embeddings" 221 222 for padding_idx in range(-(num_embeddings + 2), (num_embeddings + 2)): 223 if (padding_idx < -num_embeddings) or (padding_idx >= num_embeddings): 224 with self.assertRaisesRegex(RuntimeError, functional_err_msg): 225 F.embedding_bag(a, embeddings, padding_idx=padding_idx) 226 with self.assertRaisesRegex(AssertionError, module_err_msg): 227 torch.nn.EmbeddingBag( 228 num_embeddings, num_features, padding_idx=padding_idx 229 ) 230 else: 231 F.embedding_bag(a, embeddings, padding_idx=padding_idx) 232 torch.nn.EmbeddingBag( 233 num_embeddings, num_features, padding_idx=padding_idx 234 ) 235 236 def test_embeddingbag_from_pretrained(self): 237 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 238 embeddingbag = nn.EmbeddingBag.from_pretrained(a) 239 self.assertEqual(a, embeddingbag.weight.data) 240 241 input = torch.LongTensor([[0, 1]]) 242 output = embeddingbag(input) 243 self.assertEqual(a.mean(0, keepdim=True), output) 244 245 def test_embeddingbag_from_pretrained_options(self): 246 with set_default_dtype(torch.double): 247 a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 248 opts = { 249 "max_norm": 2.0, 250 "norm_type": 0.5, 251 "scale_grad_by_freq": False, 252 "mode": "max", 253 "sparse": False, 254 } 255 embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts) 256 257 input = torch.LongTensor([[0, 1]]) 258 output = embeddingbag(input) 259 self.assertEqual(a.max(0, keepdim=True)[0], output) 260 self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()) 261 self.assertTrue( 262 a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all() 263 ) 264 265 def test_embeddingbag_include_last_offset(self): 266 # Test case from https://github.com/pytorch/pytorch/issues/89677 267 embeddingbag = nn.EmbeddingBag(100, 3, include_last_offset=True, padding_idx=61) 268 input = torch.tensor([0, 1, 2, 3]) 269 out = embeddingbag(input, torch.tensor([0, 3, 3])) 270 out2 = embeddingbag(input, torch.tensor([0, 3, 4])) 271 272 weight = embeddingbag.weight 273 row0 = weight[0:3].mean(0) 274 row1 = weight[3] 275 ref_out = torch.stack([row0, row1]) 276 277 self.assertEqual(ref_out, out) 278 self.assertEqual(ref_out, out2) 279 280 281class TestEmbeddingNNDeviceType(NNTestCase): 282 def test_embedding_dense_grad(self, device): 283 with set_default_dtype(torch.double): 284 embd = nn.Embedding(20, 20).to(device) 285 weight = embd.weight 286 287 def fn_wrapper(device): 288 def fn(weight): 289 inp = torch.tensor( 290 [[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long 291 ).to(device) 292 return torch.nn.functional.embedding(inp, weight) 293 294 return fn 295 296 fn = fn_wrapper(device) 297 _assertGradAndGradgradChecks(self, fn, (weight,)) 298 299 def test_embedding_scalar_weight_error(self, device): 300 indices = torch.rand(2, 2, device=device).long() 301 weights = [ 302 torch.tensor(1.0, device=device), 303 torch.tensor(1.0, device=device).reshape(1, 1, 1), 304 ] 305 306 for weight in weights: 307 with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"): 308 torch.nn.functional.embedding(indices, weight) 309 310 @dtypesIfCUDA(torch.float16, torch.float64) 311 @dtypes(torch.float64) 312 def test_embedding_backward(self, device, dtype): 313 embedding = nn.Embedding(10, 3, sparse=True) 314 tensor = torch.tensor([[7, 1, 3]]) 315 ones = torch.tensor(1.0, dtype=dtype).expand(3, 3) 316 tensorTwice = tensor.repeat(1, 2) 317 onesTwice = torch.cat((ones, ones)) 318 319 embedding = embedding.to(dtype=dtype).to(device) 320 tensor = tensor.to(device) 321 ones = ones.to(device) 322 tensorTwice = tensorTwice.to(device) 323 onesTwice = onesTwice.to(device) 324 325 embedding.zero_grad() 326 embedding(tensor[0]).sum().backward() 327 self.assertEqual(embedding.weight.grad._indices(), tensor) 328 self.assertEqual(embedding.weight.grad._values(), ones) 329 330 embedding.zero_grad() 331 embedding(tensor[0]).sum().backward() 332 embedding(tensor[0]).sum().backward() 333 self.assertEqual(embedding.weight.grad._indices(), tensorTwice) 334 self.assertEqual(embedding.weight.grad._values(), onesTwice) 335 336 embedding.zero_grad() 337 embedding(tensor[0]).sum().backward() 338 tensor[0, 0] = 8 339 embedding(tensor[0]).sum().backward() 340 tensorTwice[0, 3] = 8 341 self.assertEqual(embedding.weight.grad._indices(), tensorTwice) 342 self.assertEqual(embedding.weight.grad._values(), onesTwice) 343 344 @dtypesIfCUDA( 345 *( 346 (torch.float, torch.double, torch.bfloat16, torch.half) 347 if TEST_WITH_ROCM 348 else (torch.float, torch.double, torch.half) 349 ) 350 ) 351 @dtypes(torch.float32) 352 def test_embedding_max_norm_backward(self, device, dtype): 353 # can't use gradcheck since in place renorm makes analytical gradients different from produced ones 354 weight = torch.randn((4, 4), device=device, dtype=dtype) * 2 355 weight.requires_grad_() 356 inp_list = [0, 1, 2, 2] 357 inp = torch.tensor(inp_list, device=device) 358 out = nn.functional.embedding(inp, weight, max_norm=1.0).sum() 359 out.backward() 360 361 expected_grad = ( 362 torch.tensor([[1.0, 1.0, 2.0, 0.0]], device=device, dtype=dtype) 363 .transpose(0, 1) 364 .expand(4, 4) 365 ) 366 self.assertEqual(weight.grad, expected_grad) 367 368 @dtypesIfCUDA( 369 *( 370 (torch.float, torch.double, torch.bfloat16, torch.half) 371 if TEST_WITH_ROCM 372 else (torch.float, torch.double, torch.half) 373 ) 374 ) 375 @dtypes(torch.float32) 376 def test_embedding_max_norm_fwd_AD(self, device, dtype): 377 if torch.device(device).type == "xla": 378 self.skipTest("forward AD doesn't work on xla") 379 380 # can't use gradcheck since in place renorm makes analytical gradients different from produced ones 381 weight = torch.randn((4, 4), device=device, dtype=dtype) * 2 382 tangent = torch.ones((4, 4), device=device, dtype=dtype) 383 inp = torch.tensor([[0, 1], [2, 2]], device=device) 384 with torch.autograd.forward_ad.dual_level(): 385 dual_weight = torch.autograd.forward_ad.make_dual(weight, tangent) 386 out = nn.functional.embedding(inp, dual_weight, max_norm=1.0) 387 jvp = torch.autograd.forward_ad.unpack_dual(out).tangent 388 389 expected_grad = torch.ones((2, 2, 4), device=device, dtype=dtype) 390 self.assertEqual(jvp, expected_grad) 391 392 @dtypesIfCUDA( 393 *( 394 (torch.float, torch.double, torch.bfloat16, torch.half) 395 if TEST_WITH_ROCM 396 else (torch.float, torch.double, torch.half) 397 ) 398 ) 399 @dtypes(torch.float32) 400 def test_embedding_padding_idx(self, device, dtype): 401 embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype) 402 input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device) 403 output = embedding(input) 404 self.assertEqual(output[0][0].sum(), 0) 405 self.assertEqual(output[1][2].sum(), 0) 406 407 embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True).to(device, dtype) 408 input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device) 409 output = embedding(input) 410 self.assertEqual(output[0][0].sum(), 0) 411 self.assertEqual(output[1][2].sum(), 0) 412 413 # negative indexing check for padding_idx 414 # padding_idx=-2, num_embeddings=10 ==> index 8 padded 415 embedding = nn.Embedding(10, 20, padding_idx=-2).to(device, dtype) 416 input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device) 417 output = embedding(input) 418 self.assertEqual(output[0][2].sum(), 0) 419 self.assertEqual(output[1][1].sum(), 0) 420 421 embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True).to(device, dtype) 422 input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device) 423 output = embedding(input) 424 self.assertEqual(output[0][2].sum(), 0) 425 self.assertEqual(output[1][1].sum(), 0) 426 427 # change padding vector 428 padding_vector = torch.ones(20, dtype=dtype, device=device) 429 embedding = nn.Embedding(10, 20, padding_idx=2, sparse=True).to(device, dtype) 430 with torch.no_grad(): 431 embedding.weight[2] = padding_vector 432 input = torch.tensor([0, 2], dtype=torch.long).to(device) 433 output = embedding(input) 434 self.assertEqual(output[1], padding_vector) 435 436 # out of bounds check for padding_idx 437 self.assertRaises( 438 AssertionError, 439 nn.Embedding, 440 num_embeddings=10, 441 embedding_dim=20, 442 padding_idx=25, 443 ) 444 self.assertRaises( 445 AssertionError, 446 nn.Embedding, 447 num_embeddings=10, 448 embedding_dim=20, 449 padding_idx=-25, 450 ) 451 452 padding_idx = 0 453 embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to(device, dtype) 454 for n in ( 455 1, 456 2, 457 1000, 458 ): # Need large N to trigger all the methods we have implemented 459 for other_indices in ([], [1, 3], [2]): 460 indices = torch.tensor( 461 other_indices + [padding_idx] * n, dtype=torch.long 462 ).to(device) 463 pre = embedding.weight[padding_idx].clone() 464 embedding(indices).sum().backward() 465 after = (embedding.weight + embedding.weight.grad)[padding_idx] 466 embedding.zero_grad() 467 self.assertEqual(after, pre) 468 469 # test double backward 470 emb_sum = embedding(indices).sum() 471 emb_grad = torch.autograd.grad( 472 outputs=emb_sum, 473 inputs=list(embedding.parameters()), 474 retain_graph=True, 475 ) 476 scalar = emb_grad[0].sum() + emb_sum 477 scalar.backward() 478 after = (embedding.weight + embedding.weight.grad)[padding_idx] 479 embedding.zero_grad() 480 self.assertEqual(after, pre) 481 482 # Check correctness of torch.nn.functional.embedding_bag forward and 483 # backward functions with padding_idx, given a 1D input separated into bags 484 # with an offset array. Compare against an equivalent 2D input that uses 485 # padding indices to fill in the gaps indicated by the offset array 486 487 @skipIfTorchDynamo("see https://github.com/pytorch/pytorch/pull/95621") 488 @onlyNativeDeviceTypes 489 @dtypes(torch.float32, torch.float64) 490 @dtypesIfCUDA(torch.half, torch.bfloat16) 491 def test_embedding_bag_1D_padding_idx(self, device, dtype): 492 num_features = 3 493 max_indices_per_bag = 10 494 num_bags = 10 495 num_words = 100 496 497 def gen_1D_indices_offsets(include_last_offset, allpad): 498 indices = [] 499 offsets = [] 500 cur_offset = 0 501 502 # Make one bag full and one bag empty, for extra coverage 503 empty_bag = random.randint(0, num_bags - 1) 504 full_bag = empty_bag 505 while full_bag == empty_bag: 506 full_bag = random.randint(0, num_bags - 1) 507 508 for bag in range(num_bags): 509 offsets.append(cur_offset) 510 if bag == full_bag: 511 bag_size = max_indices_per_bag 512 elif bag == empty_bag: 513 bag_size = 0 514 else: 515 bag_size = random.randint(1, max_indices_per_bag - 1) 516 indices += [ 517 1 if allpad else random.randint(0, num_words - 1) 518 for _ in range(bag_size) 519 ] 520 cur_offset += bag_size 521 522 # embedding_bag requires first entry of offsets to be 0 523 assert offsets[0] == 0 524 525 indices = torch.tensor(indices, device=device) 526 527 if include_last_offset: 528 offsets.append(indices.size(0)) 529 530 offsets = torch.tensor(offsets, device=device) 531 532 return indices, offsets 533 534 # Convert a 1-D indices-offsets representation into 2-D. Fill any empty 535 # indices with padding_idx 536 def gen_2D_indices_from_1D( 537 indices_1D, offsets, include_last_offset, padding_idx 538 ): 539 assert offsets[0] == 0 540 if include_last_offset: 541 offsets = offsets[:-1] 542 indices_2D = torch.empty( 543 num_bags, max_indices_per_bag, device=device, dtype=torch.long 544 ) 545 for bag in range(num_bags): 546 # Determine the start and end position of the bag within indices_1D 547 start = offsets[bag] 548 end = len(indices_1D) if bag + 1 == num_bags else offsets[bag + 1] 549 end = min(len(indices_1D), end) 550 551 # Pull out the bag's indices from indices_1D, and fill any 552 # remaining space with padding indices 553 indices_in_bag = [] 554 for item_pos in range(0, max_indices_per_bag): 555 if (start + item_pos) < end: 556 indices_in_bag.append(indices_1D[start + item_pos]) 557 else: 558 indices_in_bag.append(padding_idx) 559 indices_2D[bag] = torch.tensor(indices_in_bag, device=device) 560 561 return indices_2D 562 563 test_cases = product( 564 ["max", "mean", "sum"], [False, True], [False, True], [False, True] 565 ) 566 567 for mode, sparse, include_last_offset, allpad in test_cases: 568 # Max sparse and bfloat16 are not supported 569 if mode == "max": 570 if sparse or (dtype == torch.bfloat16): 571 continue 572 indices_1D, offsets = gen_1D_indices_offsets(include_last_offset, allpad) 573 for padding_idx_1D in list(set(indices_1D.tolist())) + [None]: 574 msg = ( 575 f"mode: '{mode}', sparse: {sparse}, include_last_offset: {include_last_offset}, " 576 f"padding_idx_1D: {padding_idx_1D}" 577 ) 578 579 # If 1D input does not use a padding index, we still need one for the 2D input, 580 # so we can add one dummy word to the weights to act as the padded word 581 padding_idx_2D = ( 582 padding_idx_1D if padding_idx_1D is not None else num_words 583 ) 584 num_words_with_padding = ( 585 num_words if padding_idx_1D is not None else num_words + 1 586 ) 587 588 indices_2D = gen_2D_indices_from_1D( 589 indices_1D, offsets, include_last_offset, padding_idx_2D 590 ) 591 592 weights = torch.randn( 593 num_words_with_padding, 594 num_features, 595 dtype=dtype, 596 device=device, 597 requires_grad=True, 598 ) 599 weights_check = weights.clone().detach().requires_grad_(True) 600 601 bag = torch.nn.functional.embedding_bag( 602 indices_1D, 603 weights, 604 offsets, 605 padding_idx=padding_idx_1D, 606 mode=mode, 607 sparse=sparse, 608 include_last_offset=include_last_offset, 609 ) 610 611 bag_check = torch.nn.functional.embedding_bag( 612 indices_2D, 613 weights_check, 614 padding_idx=padding_idx_2D, 615 mode=mode, 616 sparse=sparse, 617 ) 618 self.assertEqual(bag, bag_check, msg=msg) 619 620 bag.sum().backward() 621 bag_check.sum().backward() 622 623 # Sometimes, half dtype gradients mismatch by a greater amount 624 # than other dtypes 625 if dtype in [torch.half, torch.bfloat16]: 626 atol = 0.01 627 rtol = 0.01 628 else: 629 atol = None 630 rtol = None 631 self.assertEqual( 632 weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol 633 ) 634 635 # Check correctness of torch.nn.functional.embedding_bag forward and 636 # backward functions with padding_idx, given a 2D indices input. Compare 637 # against torch.nn.functional.embedding followed by a reduction. 638 @onlyNativeDeviceTypes 639 @dtypes(torch.float32, torch.float64) 640 @dtypesIfCUDA(torch.half, torch.bfloat16) 641 def test_embedding_bag_2D_padding_idx(self, device, dtype): 642 # Use a Python implementation of embedding_bag with padding_idx support 643 # to check torch.nn.functional.embedding_bag correctness 644 def embedding_bag_check(indices, weights, mode, sparse, padding_idx): 645 assert padding_idx is not None 646 embedding = torch.nn.functional.embedding( 647 indices, weights, padding_idx=padding_idx, sparse=sparse 648 ) 649 650 reduction_dim = indices.dim() - 1 651 652 if mode == "sum" or mode == "mean": 653 # We must avoid including elements at padding_idx in the 654 # sum/mean, so multiply those elements by 0, and multiply 655 # all other elements by 1 656 per_sample_weights = indices.ne(padding_idx).to(dtype).unsqueeze(-1) 657 res = embedding.mul(per_sample_weights).sum(dim=reduction_dim) 658 659 if mode == "mean": 660 weights_sum = per_sample_weights.sum(dim=reduction_dim) 661 res = res.div(weights_sum) 662 663 elif mode == "max": 664 # We must avoid allowing elements at padding_idx to be chosen 665 # as the max, so set those elements to negative infinity 666 res = embedding.masked_fill( 667 indices.unsqueeze(-1) == padding_idx, -float("inf") 668 ).amax(dim=reduction_dim) 669 670 else: 671 raise RuntimeError(f"mode '{mode}' is not available") 672 673 # If a row is all padding, set its corresponding result row to 0. 674 # This is needed because the above mean and max mode 675 # implementations set these elements to nan and -inf, respectively 676 if mode in ["mean", "max"]: 677 res = res.masked_fill( 678 indices.eq(padding_idx).all(dim=-1).unsqueeze(-1), 0 679 ) 680 681 return res 682 683 num_features = 3 684 num_words = 10 685 indices_dim1 = 10 686 687 for mode, sparse, allpad, indices_dim0 in product( 688 ["max", "mean", "sum"], [False, True], [False, True], [1, 10] 689 ): 690 # Max sparse and bfloat16 are not supported 691 if mode == "max": 692 if sparse or (dtype == torch.bfloat16): 693 continue 694 695 if allpad: 696 indices = torch.empty( 697 indices_dim0, indices_dim1, dtype=torch.long, device=device 698 ).fill_(1) 699 else: 700 indices = torch.randint( 701 0, num_words, (indices_dim0, indices_dim1), device=device 702 ) 703 704 if indices_dim0 > 1: 705 # Fill one row with duplicate index so we can test with a fully 706 # padded row 707 duplicate_row = random.randint(0, indices_dim0 - 1) 708 indices[duplicate_row] = indices[duplicate_row][0] 709 710 for padding_idx in list(set(indices.flatten(0, -1).tolist())): 711 weights = torch.randn( 712 num_words, 713 num_features, 714 dtype=dtype, 715 device=device, 716 requires_grad=True, 717 ) 718 weights_check = weights.clone().detach().requires_grad_(True) 719 720 msg = ( 721 f"mode: '{mode}', sparse: {sparse}, padding_idx: {padding_idx}, " 722 f"allpad: {allpad}, indices.size(): {indices.size()}" 723 ) 724 725 # Check forward with a Python implementation of padding_idx embedding_bag 726 bag_check = embedding_bag_check( 727 indices, weights_check, mode, sparse, padding_idx 728 ) 729 bag = torch.nn.functional.embedding_bag( 730 indices, weights, padding_idx=padding_idx, mode=mode, sparse=sparse 731 ) 732 733 self.assertEqual(bag, bag_check, msg=msg) 734 735 bag_check.sum().backward() 736 grad_check = weights_check.grad 737 738 bag.sum().backward() 739 grad = weights.grad 740 741 # Sometimes, half dtype gradients mismatch by a greater amount 742 # than other dtypes 743 if dtype in [torch.half, torch.bfloat16]: 744 atol = 0.01 745 rtol = 0.01 746 else: 747 atol = None 748 rtol = None 749 self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol) 750 751 @onlyCUDA 752 @dtypes( 753 *( 754 (torch.float, torch.double, torch.bfloat16, torch.half) 755 if TEST_WITH_ROCM 756 else (torch.float, torch.double, torch.half) 757 ) 758 ) 759 def test_embedding_max_norm_device(self, device, dtype): 760 embedding = nn.Embedding(22, 5, max_norm=1.0).to(device, dtype=dtype) 761 # nn.Embedding only takes LongTensor as input 762 input = torch.tensor([2, 8, 8, 6], device=device, dtype=torch.long) 763 output = embedding(input) 764 self.assertEqual(output[1], output[2]) 765 self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) 766 767 @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 768 def test_embedding_bag_empty_input(self, device, dtypes): 769 m = 4 770 n = 3 771 x = torch.tensor([], device=device, dtype=dtypes[0]) 772 for sparse in [True, False]: 773 Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) 774 Embed.to(device) 775 776 output = Embed( 777 input=x, offsets=torch.tensor([0], device=device, dtype=dtypes[1]) 778 ) 779 self.assertEqual(output, torch.zeros_like(output)) 780 781 output = Embed( 782 input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtypes[1]) 783 ) 784 self.assertEqual(output, torch.zeros_like(output)) 785 786 @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.") 787 @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long))) 788 @parametrize_test("padding_idx", [None, 0]) 789 @parametrize_test("mode", ["sum", "mean", "max"]) 790 def test_embedding_bag_out_of_bounds_idx(self, device, dtypes, padding_idx, mode): 791 padding_idx = 0 792 w_dtype, idx_dtype = dtypes 793 # negative out-of-bound 794 idx1 = torch.tensor([[-1, 1]], device=device, dtype=idx_dtype) 795 # positive out-of-bound 796 idx2 = torch.tensor([[11, 8]], device=device, dtype=idx_dtype) 797 weight = torch.randn(10, 2, device=device, dtype=w_dtype) 798 if mode == "sum": 799 # Only `sum` supports per_sample_weight 800 per_sample_weights = ( 801 None, 802 torch.randn_like(idx1, device=device, dtype=w_dtype), 803 ) 804 else: 805 per_sample_weights = (None,) 806 807 for p_s_weights, idx in itertools.product(per_sample_weights, (idx1, idx2)): 808 msg = "Expected idx >= 0 && idx < num_embeddings" 809 with self.assertRaisesRegex(RuntimeError, msg): 810 torch.nn.functional.embedding_bag( 811 idx, 812 weight, 813 per_sample_weights=p_s_weights, 814 padding_idx=padding_idx, 815 mode=mode, 816 ) 817 818 def test_embedding_bag_dimension_errors(self, device): 819 funcs = ( 820 lambda x, y, z: torch.nn.functional.embedding_bag(y, x, z), 821 torch.embedding_bag, 822 torch._embedding_bag, 823 torch._embedding_bag_forward_only, 824 ) 825 for i, f in enumerate(funcs): 826 err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError 827 828 weight = torch.full( 829 ( 830 2, 831 6, 832 ), 833 0, 834 dtype=torch.float64, 835 device=device, 836 ) 837 indices = torch.full( 838 ( 839 2, 840 0, 841 0, 842 6, 843 6, 844 ), 845 2, 846 dtype=torch.int64, 847 device=device, 848 ) 849 offsets = torch.full((2, 0, 0, 6, 6), 0, dtype=torch.int64, device=device) 850 851 if i == 0: 852 error_msg = "input has to be 1D or 2D Tensor" 853 else: 854 error_msg = "input has to be a 1D or 2D Tensor" 855 torch._dynamo.disable(self.assertRaisesRegex)( 856 err_type, error_msg, lambda: f(weight, indices, offsets) 857 ) 858 859 weight = torch.full((2, 2), 0, dtype=torch.float64, device=device) 860 indices = torch.full((2,), 1, dtype=torch.int64, device=device) 861 862 torch._dynamo.disable(self.assertRaisesRegex)( 863 err_type, 864 "offsets has to be a 1D Tensor", 865 lambda: f(weight, indices, offsets), 866 ) 867 868 weight = torch.full((2, 2, 2), 0, dtype=torch.float64, device=device) 869 indices = torch.full((2,), 2, dtype=torch.int64, device=device) 870 offsets = torch.full((2,), 0, dtype=torch.int64, device=device) 871 872 torch._dynamo.disable(self.assertRaisesRegex)( 873 err_type, 874 "weight has to be a 2D Tensor", 875 lambda: f(weight, indices, offsets), 876 ) 877 878 @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 879 def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes): 880 # Failure 1: mismatched embeddings / per_sample_weights dtype 881 es = nn.EmbeddingBag(5, 2, mode="sum").to(dtype=torch.float, device=device) 882 input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device) 883 offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device) 884 per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device) 885 if device == "cpu": 886 with self.assertRaisesRegex(RuntimeError, "have the same type as"): 887 es(input, offsets, per_sample_weights) 888 else: 889 with self.assertRaisesRegex(RuntimeError, "expected scalar type"): 890 es(input, offsets, per_sample_weights) 891 892 # Failure 2.1: input/per_sample_weights have different sizes (1d input) 893 input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device) 894 offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device) 895 per_sample_weights = torch.randn(5, dtype=torch.float, device=device) 896 with self.assertRaisesRegex(ValueError, "same shape as the input"): 897 es(input, offsets, per_sample_weights) 898 899 # Failure 2.2: input/per_sample_weights have different sizes (2d input) 900 input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device) 901 offsets = None 902 per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device) 903 with self.assertRaisesRegex(ValueError, "same shape as the input"): 904 es(input, offsets, per_sample_weights) 905 906 # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean') 907 for unsupported_mode in ("max", "mean"): 908 es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to( 909 dtype=torch.float, device=device 910 ) 911 input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device) 912 offsets = None 913 per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device) 914 with self.assertRaisesRegex( 915 NotImplementedError, "only supported for mode='sum'" 916 ): 917 es(input, offsets, per_sample_weights) 918 919 def _embedding_bag_reference_impl( 920 self, 921 input, 922 weight, 923 offsets=None, 924 mode="sum", 925 per_sample_weights=None, 926 include_last_offset=False, 927 ): 928 assert mode == "sum" or per_sample_weights is None 929 assert offsets is not None 930 if per_sample_weights is None: 931 per_sample_weights = torch.ones(input.size()).to( 932 dtype=weight.dtype, device=weight.device 933 ) 934 assert input.numel() == per_sample_weights.numel() 935 936 bags = [] 937 long_input = input.to(torch.long) 938 embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze( 939 1 940 ) 941 if include_last_offset: 942 for index in range(len(offsets) - 1): 943 offset = offsets[index] 944 next_offset = offsets[index + 1] 945 length = next_offset - offset 946 if length == 0: 947 bags.append( 948 torch.tensor([0] * weight.size(1)).to( 949 dtype=embeddings.dtype, device=embeddings.device 950 ) 951 ) 952 else: 953 if mode == "sum": 954 bags.append(embeddings.narrow(0, offset, length).sum(0)) 955 elif mode == "mean": 956 bags.append( 957 embeddings.narrow(0, offset, length).sum(0).div(length) 958 ) 959 else: 960 assert mode == "max" 961 bags.append(embeddings.narrow(0, offset, length).max(0)[0]) 962 else: 963 for index, offset in enumerate(offsets): 964 if index + 1 < len(offsets): 965 next_offset = offsets[index + 1] 966 else: 967 next_offset = len(long_input) 968 length = next_offset - offset 969 if length == 0: 970 bags.append( 971 torch.tensor([0] * weight.size(1)).to( 972 dtype=embeddings.dtype, device=embeddings.device 973 ) 974 ) 975 else: 976 if mode == "sum": 977 bags.append(embeddings.narrow(0, offset, length).sum(0)) 978 elif mode == "mean": 979 bags.append( 980 embeddings.narrow(0, offset, length).sum(0).div(length) 981 ) 982 else: 983 assert mode == "max" 984 bags.append(embeddings.narrow(0, offset, length).max(0)[0]) 985 return torch.stack(bags) 986 987 @skipMeta 988 @dtypes( 989 *itertools.product( 990 (torch.int, torch.long), 991 (torch.int, torch.long), 992 (torch.half, torch.bfloat16, torch.float, torch.double), 993 ) 994 ) 995 @dtypesIfCUDA( 996 *itertools.product( 997 (torch.int, torch.long), 998 (torch.int, torch.long), 999 (torch.float, torch.double, torch.half), 1000 ) 1001 ) 1002 def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes): 1003 # Test empty input and per sample weight, and backward pass. There was a CUDA 1004 # invalid configuration bug (more context in #46572) 1005 def test_per_sample_weights(mode, trainable_scale): 1006 es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device) 1007 es.weight.data.copy_( 1008 torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2]) 1009 ) 1010 input = torch.tensor([], device=device, dtype=dtypes[0]) 1011 offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[1]) 1012 per_sample_weights = torch.randn_like( 1013 input, dtype=dtypes[2] 1014 ).requires_grad_(trainable_scale) 1015 ref_per_sample_weights = per_sample_weights.detach().requires_grad_( 1016 trainable_scale 1017 ) 1018 reference_weights = es.weight.detach().requires_grad_() 1019 1020 expected = self._embedding_bag_reference_impl( 1021 input, reference_weights, offsets, mode, ref_per_sample_weights 1022 ) 1023 result = es(input, offsets, per_sample_weights) 1024 self.assertEqual( 1025 result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0 1026 ) 1027 1028 grad = torch.randn_like(expected) 1029 result.backward(grad) 1030 # the reference impl doesn't have grad fn for empty input; but the grad should 1031 # simply be a zero tensor 1032 ref_weights_grad = torch.zeros_like(es.weight) 1033 self.assertEqual( 1034 es.weight.grad, 1035 ref_weights_grad, 1036 atol=dtype2prec_DONTUSE[dtypes[2]], 1037 rtol=0, 1038 ) 1039 if trainable_scale: 1040 ref_per_sample_weights_grad = torch.empty_like(per_sample_weights) 1041 self.assertEqual( 1042 per_sample_weights.grad, 1043 ref_per_sample_weights_grad, 1044 atol=dtype2prec_DONTUSE[dtypes[2]], 1045 rtol=0, 1046 ) 1047 1048 modes = ("sum",) 1049 trainable_scale = (True, False) 1050 for mode, trainable in itertools.product(modes, trainable_scale): 1051 test_per_sample_weights(mode, trainable) 1052 1053 @skipMeta 1054 @dtypes( 1055 *itertools.product( 1056 (torch.int, torch.long), 1057 (torch.int, torch.long), 1058 (torch.float, torch.double, torch.half, torch.bfloat16), 1059 ) 1060 ) 1061 @dtypesIfCUDA( 1062 *itertools.product( 1063 (torch.int, torch.long), 1064 (torch.int, torch.long), 1065 (torch.float, torch.double, torch.half), 1066 ) 1067 ) 1068 def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes): 1069 def test_per_sample_weights(mode, trainable_scale): 1070 es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device) 1071 es.weight.data.copy_( 1072 torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2]) 1073 ) 1074 input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) 1075 offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1]) 1076 per_sample_weights = torch.randn_like( 1077 input, dtype=dtypes[2] 1078 ).requires_grad_(trainable_scale) 1079 ref_per_sample_weights = per_sample_weights.detach().requires_grad_( 1080 trainable_scale 1081 ) 1082 reference_weights = es.weight.detach().requires_grad_() 1083 1084 expected = self._embedding_bag_reference_impl( 1085 input, reference_weights, offsets, mode, ref_per_sample_weights 1086 ) 1087 result = es(input, offsets, per_sample_weights) 1088 self.assertEqual( 1089 result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0 1090 ) 1091 1092 grad = torch.randn_like(expected).to(dtype=dtypes[2], device=device) 1093 result.backward(grad) 1094 expected.backward(grad) 1095 self.assertEqual( 1096 es.weight.grad, 1097 reference_weights.grad, 1098 atol=dtype2prec_DONTUSE[dtypes[2]], 1099 rtol=0, 1100 ) 1101 if trainable_scale: 1102 self.assertEqual( 1103 per_sample_weights.grad, 1104 ref_per_sample_weights.grad, 1105 atol=dtype2prec_DONTUSE[dtypes[2]], 1106 rtol=0, 1107 ) 1108 1109 modes = ("sum",) 1110 trainable_scale = (True, False) 1111 for mode, trainable in itertools.product(modes, trainable_scale): 1112 test_per_sample_weights(mode, trainable) 1113 1114 @skipMeta 1115 @dtypes( 1116 *itertools.product( 1117 (torch.int, torch.long), 1118 (torch.int, torch.long), 1119 (torch.float, torch.double, torch.half, torch.bfloat16), 1120 ) 1121 ) 1122 @dtypesIfCUDA( 1123 *itertools.product( 1124 (torch.int, torch.long), 1125 (torch.int, torch.long), 1126 (torch.float, torch.double, torch.half), 1127 ) 1128 ) 1129 def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes): 1130 def test_per_sample_weights_new_offsets( 1131 mode, trainable_scale, include_last_offset, has_weight=True 1132 ): 1133 es = nn.EmbeddingBag( 1134 5, 2, mode=mode, include_last_offset=include_last_offset 1135 ).to(dtype=dtypes[2], device=device) 1136 es.weight.data.copy_( 1137 torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2]) 1138 ) 1139 input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) 1140 offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1]) 1141 1142 if include_last_offset: 1143 offsets = torch.cat( 1144 ( 1145 offsets, 1146 torch.tensor([input.size(0)], device=device, dtype=dtypes[1]), 1147 ), 1148 0, 1149 ) 1150 1151 if has_weight: 1152 per_sample_weights = torch.randn_like( 1153 input, device=device, dtype=dtypes[2] 1154 ).requires_grad_(trainable_scale) 1155 ref_per_sample_weights = per_sample_weights.detach().requires_grad_( 1156 trainable_scale 1157 ) 1158 else: 1159 per_sample_weights = None 1160 ref_per_sample_weights = None 1161 1162 reference_weights = es.weight.detach().requires_grad_() 1163 1164 expected = self._embedding_bag_reference_impl( 1165 input, 1166 reference_weights, 1167 offsets, 1168 mode, 1169 ref_per_sample_weights, 1170 include_last_offset, 1171 ) 1172 result = es(input, offsets, per_sample_weights) 1173 self.assertEqual( 1174 result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0 1175 ) 1176 1177 grad = torch.randn_like(expected) 1178 result.backward(grad) 1179 expected.backward(grad) 1180 self.assertEqual( 1181 es.weight.grad, 1182 reference_weights.grad, 1183 atol=dtype2prec_DONTUSE[dtypes[2]], 1184 rtol=0, 1185 ) 1186 if has_weight and trainable_scale: 1187 self.assertEqual( 1188 per_sample_weights.grad, 1189 ref_per_sample_weights.grad, 1190 atol=dtype2prec_DONTUSE[dtypes[2]], 1191 rtol=0, 1192 ) 1193 1194 trainable_scale = (True, False) 1195 include_last_offset_list = (True, False) 1196 modes = (("sum", False), ("sum", True), ("max", False), ("mean", False)) 1197 for (mode, has_weight), trainable, include_last_offset in itertools.product( 1198 modes, trainable_scale, include_last_offset_list 1199 ): 1200 test_per_sample_weights_new_offsets( 1201 mode, trainable, include_last_offset, has_weight 1202 ) 1203 1204 def _test_EmbeddingBag_vs_Embedding( 1205 self, 1206 N, 1207 D, 1208 B, 1209 L, 1210 max_norm=None, 1211 mode="mean", 1212 device="cpu", 1213 wdtype=torch.float, 1214 dtype=torch.long, 1215 test_per_sample_weights=False, 1216 trainable_per_sample_weights=False, 1217 sparse=False, 1218 test_backward=True, 1219 backward_prec=None, 1220 ): 1221 es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to( 1222 device, wdtype 1223 ) 1224 e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype) 1225 e.weight.data.copy_(es.weight) 1226 input = torch.randint(N, (B, L), device=device, dtype=dtype) 1227 offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L) 1228 grad_output = torch.rand(B, D, device=device, dtype=wdtype) 1229 1230 if test_per_sample_weights: 1231 # To prevent large gradients, weights should sum to 1 for each bag 1232 per_sample_weights = torch.randn(B, L, device=device, dtype=wdtype).softmax( 1233 dim=-1 1234 ) 1235 per_sample_weights_reference = per_sample_weights.clone().requires_grad_( 1236 trainable_per_sample_weights 1237 ) 1238 per_sample_weights.requires_grad_(trainable_per_sample_weights) 1239 output = es(input.view(-1), offsets, per_sample_weights.view(-1)) 1240 else: 1241 output = es(input.view(-1), offsets) 1242 per_sample_weights = None 1243 per_sample_weights_reference = None 1244 1245 if mode == "sum": 1246 if test_per_sample_weights: 1247 ref_output = ( 1248 e(input) * per_sample_weights_reference.unsqueeze(-1) 1249 ).sum(1) 1250 else: 1251 ref_output = e(input).sum(1) 1252 elif mode == "mean": 1253 assert not test_per_sample_weights 1254 ref_output = e(input).mean(1) 1255 elif mode == "max": 1256 assert not test_per_sample_weights 1257 ref_output = e(input).max(1)[0] 1258 1259 self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0) 1260 1261 if not test_backward: 1262 return 1263 1264 output.backward(grad_output) 1265 ref_output.backward(grad_output) 1266 es_weight_grad = es.weight.grad 1267 if sparse: 1268 es_weight_grad = es.weight.grad.to_dense() 1269 1270 # We have more floating point error here because we are dealing with larger numbers 1271 if backward_prec is None: 1272 needed_prec = dtype2prec_DONTUSE[wdtype] * 5 1273 rtol = 0.02 if wdtype == torch.half else 0 1274 else: 1275 needed_prec = backward_prec 1276 rtol = 0 1277 1278 self.assertEqual(es_weight_grad, e.weight.grad, atol=needed_prec, rtol=rtol) 1279 1280 if test_per_sample_weights and trainable_per_sample_weights: 1281 self.assertEqual( 1282 per_sample_weights.grad, 1283 per_sample_weights_reference.grad, 1284 atol=dtype2prec_DONTUSE[wdtype], 1285 rtol=0, 1286 ) 1287 1288 @dtypesIfCUDA( 1289 *itertools.product( 1290 (torch.int, torch.long), (torch.half, torch.float, torch.double) 1291 ) 1292 ) 1293 @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) 1294 def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes): 1295 def run_tests(mode, sparse, trainable_per_sample_weights): 1296 kwargs = dict( 1297 test_per_sample_weights=True, 1298 device=device, 1299 mode=mode, 1300 wdtype=dtypes[1], 1301 dtype=dtypes[0], 1302 sparse=sparse, 1303 trainable_per_sample_weights=trainable_per_sample_weights, 1304 ) 1305 1306 # Simple case 1307 self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs) 1308 1309 # B * L > 1000 1310 self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs) 1311 1312 # Large num_embedding 1313 self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs) 1314 1315 # Large embedding_dim 1316 self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs) 1317 1318 modes = ("sum",) 1319 sparsity = (True, False) 1320 trainable_scale = (True, False) 1321 for mode, sparse, trainable_per_sample_weights in itertools.product( 1322 modes, sparsity, trainable_scale 1323 ): 1324 run_tests(mode, sparse, trainable_per_sample_weights) 1325 1326 # Test CUDA Dense on half precision 1327 if device == "cuda": 1328 modes = ("sum",) 1329 sparsity = (False,) 1330 trainable_scale = (True, False) 1331 for mode, sparse, trainable_per_sample_weights in itertools.product( 1332 modes, sparsity, trainable_scale 1333 ): 1334 run_tests(mode, sparse, trainable_per_sample_weights) 1335 1336 def _test_EmbeddingBag( 1337 self, 1338 device, 1339 mode, 1340 sparse, 1341 wdtype=torch.double, 1342 dtype=torch.long, 1343 odtype=torch.long, 1344 test_backward=True, 1345 ): 1346 # check a known test example 1347 es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype) 1348 es.weight.data.copy_( 1349 torch.arange(1, 11, device=device).view_as(es.weight).to(wdtype) 1350 ) 1351 input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype) 1352 offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=odtype) 1353 1354 grad_output = torch.tensor([1, 2, 3, 4], device=device, dtype=wdtype).view(2, 2) 1355 grad_output_with_empty = torch.tensor( 1356 [99, 99, 1, 2, 99, 99, 3, 4, 99, 99], device=device, dtype=wdtype 1357 ).view(5, 2) 1358 1359 if mode == "sum" or mode == "mean": 1360 denominator = 1 if mode == "sum" else 3 1361 expected_output = ( 1362 torch.tensor([[13, 16], [13, 16]], device=device, dtype=wdtype) 1363 / denominator 1364 ) 1365 1366 expected_output_with_empty = ( 1367 torch.tensor( 1368 [[0, 0], [13, 16], [0, 0], [13, 16], [0, 0]], 1369 device=device, 1370 dtype=wdtype, 1371 ) 1372 / denominator 1373 ) 1374 1375 expected_grad_weight = ( 1376 torch.tensor( 1377 [[3, 4], [5, 8], [0, 0], [1, 2], [3, 4]], 1378 device=device, 1379 dtype=wdtype, 1380 ) 1381 / denominator 1382 ) 1383 elif mode == "max": 1384 expected_output = torch.tensor( 1385 [[7, 8], [9, 10]], device=device, dtype=wdtype 1386 ) 1387 1388 expected_output_with_empty = torch.tensor( 1389 [[0, 0], [7, 8], [0, 0], [9, 10], [0, 0]], device=device, dtype=wdtype 1390 ) 1391 1392 expected_grad_weight = torch.tensor( 1393 [[0, 0], [0, 0], [0, 0], [1, 2], [3, 4]], device=device, dtype=wdtype 1394 ) 1395 output = es(input, offsets) 1396 output.backward(grad_output_with_empty) 1397 1398 es_weight_grad = es.weight.grad 1399 if sparse: 1400 es_weight_grad = es.weight.grad.to_dense() 1401 self.assertEqual(output, expected_output_with_empty) 1402 self.assertEqual( 1403 es_weight_grad, 1404 expected_grad_weight, 1405 atol=dtype2prec_DONTUSE[wdtype], 1406 rtol=0, 1407 ) 1408 1409 # check same example except as 2D (2 x 3) 1410 input = input.view(2, -1) 1411 es.zero_grad() 1412 output = es(input) 1413 output.backward(grad_output) 1414 1415 es_weight_grad = es.weight.grad 1416 if sparse: 1417 es_weight_grad = es.weight.grad.to_dense() 1418 self.assertEqual(output, expected_output) 1419 self.assertEqual( 1420 es_weight_grad, 1421 expected_grad_weight, 1422 atol=dtype2prec_DONTUSE[wdtype], 1423 rtol=0, 1424 ) 1425 1426 # test all empty bags 1427 es.zero_grad() 1428 inputs = torch.tensor([], dtype=dtype, device=device) 1429 offsets = torch.tensor([0, 0, 0, 0], dtype=odtype, device=device) 1430 es(inputs, offsets).sum().backward() 1431 dense_grad = es.weight.grad 1432 if dense_grad.is_sparse: 1433 dense_grad = dense_grad.to_dense() 1434 self.assertEqual(dense_grad, torch.zeros_like(es.weight)) 1435 1436 # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length 1437 N, D, B, L = ( 1438 random.randint(1, 100), 1439 random.randint(1, 100), 1440 random.randint(1, 50), 1441 random.randint(1, 50), 1442 ) 1443 kwargs = dict( 1444 mode=mode, 1445 sparse=sparse, 1446 device=device, 1447 wdtype=wdtype, 1448 dtype=dtype, 1449 test_backward=test_backward, 1450 ) 1451 self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs) 1452 for max_norm in (None, 3): 1453 for p in itertools.product([1, 2], repeat=4): 1454 self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs) 1455 1456 # check that giving illegal input combos raises error 1457 es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse) 1458 input = torch.ones(3, 4, dtype=dtype) 1459 offset = torch.arange(0, 3, dtype=odtype) 1460 torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input, offset)) 1461 torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input.view(-1))) 1462 offset[0] = 1 1463 if self.device_type == "cpu": 1464 torch._dynamo.disable(self.assertRaises)( 1465 RuntimeError, lambda: es(input.view(-1), offset) 1466 ) 1467 offset[0] = 0 1468 offset[-1] = 100 1469 torch._dynamo.disable(self.assertRaises)( 1470 RuntimeError, lambda: es(input.view(-1), offset) 1471 ) 1472 1473 @skipMeta 1474 @dtypes( 1475 *itertools.product( 1476 (torch.int, torch.long), 1477 (torch.int, torch.long), 1478 (torch.float, torch.double, torch.half, torch.bfloat16), 1479 ) 1480 ) 1481 @dtypesIfCUDA( 1482 *itertools.product( 1483 (torch.int, torch.long), 1484 (torch.int, torch.long), 1485 (torch.float, torch.double, torch.half), 1486 ) 1487 ) 1488 def test_embedding_bag_device(self, device, dtypes): 1489 if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu": 1490 self.skipTest("bfloat16 not supported with Jetson cpu") 1491 with set_default_dtype(torch.double): 1492 self._test_EmbeddingBag( 1493 device, 1494 "sum", 1495 False, 1496 wdtype=dtypes[2], 1497 dtype=dtypes[0], 1498 odtype=dtypes[1], 1499 ) 1500 self._test_EmbeddingBag( 1501 device, 1502 "mean", 1503 False, 1504 wdtype=dtypes[2], 1505 dtype=dtypes[0], 1506 odtype=dtypes[1], 1507 ) 1508 self._test_EmbeddingBag( 1509 device, 1510 "max", 1511 False, 1512 wdtype=dtypes[2], 1513 dtype=dtypes[0], 1514 odtype=dtypes[1], 1515 ) 1516 1517 test_backward = False 1518 if self.device_type == "cuda": 1519 # see 'todo' in test_embedding_bag. 1520 test_backward = dtypes[2] is not torch.float16 1521 elif self.device_type == "cpu": 1522 # TODO: figure out why precision on sparse embeddings isn't the 1523 # same as for dense. 1524 test_backward = ( 1525 dtypes[2] is not torch.float and dtypes[2] is not torch.float16 1526 ) 1527 1528 self._test_EmbeddingBag( 1529 device, 1530 "sum", 1531 True, 1532 wdtype=dtypes[2], 1533 dtype=dtypes[0], 1534 odtype=dtypes[1], 1535 test_backward=test_backward, 1536 ) 1537 self._test_EmbeddingBag( 1538 device, 1539 "mean", 1540 True, 1541 wdtype=dtypes[2], 1542 dtype=dtypes[0], 1543 odtype=dtypes[1], 1544 test_backward=test_backward, 1545 ) 1546 1547 @skipMeta 1548 @dtypes( 1549 *itertools.product( 1550 (torch.int, torch.long), 1551 (torch.int, torch.long), 1552 (torch.float, torch.double, torch.half, torch.bfloat16), 1553 ) 1554 ) 1555 @dtypesIfCUDA( 1556 *itertools.product( 1557 (torch.int, torch.long), 1558 (torch.int, torch.long), 1559 (torch.float, torch.double, torch.half), 1560 ) 1561 ) 1562 def test_embedding_bag_non_contiguous_weight(self, device, dtypes): 1563 weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device) 1564 1565 weight_tensor_non_contig = weight_tensor[ 1566 :, :3 1567 ] # This is non-contiguous strided. 1568 weight_tensor_contig = ( 1569 weight_tensor_non_contig.clone().contiguous() 1570 ) # Contig-strided. 1571 1572 index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device) 1573 offsets = torch.tensor([0, 2], dtype=dtypes[1], device=device) 1574 for mode in ["sum", "mean", "max"]: 1575 output_non_contig = F.embedding_bag( 1576 input=index, 1577 weight=weight_tensor_non_contig, 1578 offsets=offsets, 1579 mode=mode, 1580 ) 1581 output_contig = F.embedding_bag( 1582 input=index, 1583 weight=weight_tensor_contig, 1584 offsets=offsets, 1585 mode=mode, 1586 ) 1587 self.assertEqual(output_non_contig, output_contig) 1588 1589 @onlyNativeDeviceTypes # currently fails on XLA 1590 @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 1591 def test_embedding_bag_bfloat16(self, device, dtypes): 1592 with set_default_dtype(torch.double): 1593 self._test_EmbeddingBag( 1594 device, 1595 "sum", 1596 True, 1597 wdtype=torch.bfloat16, 1598 dtype=dtypes[0], 1599 odtype=dtypes[1], 1600 test_backward=True, 1601 ) 1602 self._test_EmbeddingBag( 1603 device, 1604 "mean", 1605 True, 1606 wdtype=torch.bfloat16, 1607 dtype=dtypes[0], 1608 odtype=dtypes[1], 1609 test_backward=True, 1610 ) 1611 1612 @onlyNativeDeviceTypes # currently fails on XLA 1613 @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 1614 def test_embedding_bag_half(self, device, dtypes): 1615 self._test_EmbeddingBag( 1616 device, 1617 "sum", 1618 True, 1619 wdtype=torch.float16, 1620 dtype=dtypes[0], 1621 odtype=dtypes[1], 1622 test_backward=True, 1623 ) 1624 1625 1626instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals()) 1627instantiate_parametrized_tests(TestEmbeddingNN) 1628 1629if __name__ == "__main__": 1630 run_tests() 1631