1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: nn"] 2*da0073e9SAndroid Build Coastguard Workerimport itertools 3*da0073e9SAndroid Build Coastguard Workerimport random 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 9*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 12*da0073e9SAndroid Build Coastguard Worker dtypes, 13*da0073e9SAndroid Build Coastguard Worker dtypesIfCUDA, 14*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 15*da0073e9SAndroid Build Coastguard Worker largeTensorTest, 16*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 17*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypes, 18*da0073e9SAndroid Build Coastguard Worker skipCUDAIf, 19*da0073e9SAndroid Build Coastguard Worker skipMeta, 20*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM, 21*da0073e9SAndroid Build Coastguard Worker) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import NNTestCase 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 25*da0073e9SAndroid Build Coastguard Worker _assertGradAndGradgradChecks, 26*da0073e9SAndroid Build Coastguard Worker dtype2prec_DONTUSE, 27*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 28*da0073e9SAndroid Build Coastguard Worker IS_JETSON, 29*da0073e9SAndroid Build Coastguard Worker parametrize as parametrize_test, 30*da0073e9SAndroid Build Coastguard Worker run_tests, 31*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 32*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 33*da0073e9SAndroid Build Coastguard Worker) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Workerclass TestEmbeddingNN(NNTestCase): 37*da0073e9SAndroid Build Coastguard Worker _do_cuda_memory_leak_check = True 38*da0073e9SAndroid Build Coastguard Worker _do_cuda_non_default_stream = True 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 41*da0073e9SAndroid Build Coastguard Worker def test_embedding_max_norm_unsorted_repeating_indices(self): 42*da0073e9SAndroid Build Coastguard Worker def create_embedding(device): 43*da0073e9SAndroid Build Coastguard Worker # Seed RNG so we get the same Embedding each time 44*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 45*da0073e9SAndroid Build Coastguard Worker return torch.nn.Embedding( 46*da0073e9SAndroid Build Coastguard Worker num_embeddings=20, embedding_dim=64, max_norm=1.0 47*da0073e9SAndroid Build Coastguard Worker ).to(device) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000) 50*da0073e9SAndroid Build Coastguard Worker out_cpu = create_embedding("cpu")(ix) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker ix = ix.to("cuda") 53*da0073e9SAndroid Build Coastguard Worker out = create_embedding("cuda")(ix) 54*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.cpu(), out_cpu) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker def test_embedding_sparse_basic(self): 57*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 20, sparse=True) 58*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long) 59*da0073e9SAndroid Build Coastguard Worker embedding(input).sum().backward() 60*da0073e9SAndroid Build Coastguard Worker self.assertTrue(embedding.weight.grad.is_sparse) 61*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker def test_embedding_sparse_empty_tensor(self): 64*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(0, 0, sparse=True) 65*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([], dtype=torch.int64) 66*da0073e9SAndroid Build Coastguard Worker embedding(input).sum().backward() 67*da0073e9SAndroid Build Coastguard Worker self.assertTrue(embedding.weight.grad.is_sparse) 68*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 0, sparse=True) 71*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]) 72*da0073e9SAndroid Build Coastguard Worker embedding(input).sum().backward() 73*da0073e9SAndroid Build Coastguard Worker self.assertTrue(embedding.weight.grad.is_sparse) 74*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def test_move_sparse_half_embedding(self): 77*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 3, sparse=True) 78*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.device.type, "cpu") 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.dtype, torch.get_default_dtype()) 80*da0073e9SAndroid Build Coastguard Worker embedding.to(torch.float16) 81*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.dtype, torch.float16) 82*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.embedding_dim, 3) 83*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.num_embeddings, 10) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 86*da0073e9SAndroid Build Coastguard Worker embedding.to("cuda") 87*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.device.type, "cuda") 88*da0073e9SAndroid Build Coastguard Worker embedding.to("cpu") 89*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.device.type, "cpu") 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker def test_embedding_max_norm(self): 92*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(22, 5, max_norm=1.0) 93*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([2, 8, 8, 6], dtype=torch.long) 94*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 95*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1], output[2]) 96*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker @parametrize_test( 99*da0073e9SAndroid Build Coastguard Worker "dtype", 100*da0073e9SAndroid Build Coastguard Worker ( 101*da0073e9SAndroid Build Coastguard Worker torch.uint8, 102*da0073e9SAndroid Build Coastguard Worker torch.int8, 103*da0073e9SAndroid Build Coastguard Worker torch.int16, 104*da0073e9SAndroid Build Coastguard Worker torch.int32, 105*da0073e9SAndroid Build Coastguard Worker torch.int64, 106*da0073e9SAndroid Build Coastguard Worker torch.float, 107*da0073e9SAndroid Build Coastguard Worker torch.double, 108*da0073e9SAndroid Build Coastguard Worker ), 109*da0073e9SAndroid Build Coastguard Worker ) 110*da0073e9SAndroid Build Coastguard Worker def test_embedding_from_pretrained(self, dtype): 111*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 112*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding.from_pretrained(a) 113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, embedding.weight.data) 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor([0, 1]) 116*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, output) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_from_pretrained(self): 120*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 121*da0073e9SAndroid Build Coastguard Worker embedding = nn.EmbeddingBag.from_pretrained(a) 122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, embedding.weight) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([0, 1], dtype=torch.long) 125*da0073e9SAndroid Build Coastguard Worker output = embedding(input, torch.arange(input.size(0))) 126*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, output) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker def test_embedding_from_pretrained_padding_idx(self): 129*da0073e9SAndroid Build Coastguard Worker padding_idx = 2 130*da0073e9SAndroid Build Coastguard Worker padding_vec = torch.ones(3) * 7 131*da0073e9SAndroid Build Coastguard Worker embeddings = torch.rand(4, 3, requires_grad=True) 132*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 133*da0073e9SAndroid Build Coastguard Worker embeddings[padding_idx] = padding_vec 134*da0073e9SAndroid Build Coastguard Worker embedding_nn = nn.Embedding.from_pretrained(embeddings, padding_idx=padding_idx) 135*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding_nn.weight[padding_idx], padding_vec) 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_from_pretrained_padding_idx(self): 138*da0073e9SAndroid Build Coastguard Worker padding_idx = 2 139*da0073e9SAndroid Build Coastguard Worker embeddings = torch.rand(4, 3, requires_grad=True) 140*da0073e9SAndroid Build Coastguard Worker embedding_nn = nn.EmbeddingBag.from_pretrained( 141*da0073e9SAndroid Build Coastguard Worker embeddings, padding_idx=padding_idx 142*da0073e9SAndroid Build Coastguard Worker ) 143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding_nn.weight, embeddings) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker def test_embedding_from_pretrained_options(self): 146*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 147*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 148*da0073e9SAndroid Build Coastguard Worker opts = { 149*da0073e9SAndroid Build Coastguard Worker "max_norm": 2.0, 150*da0073e9SAndroid Build Coastguard Worker "norm_type": 0.5, 151*da0073e9SAndroid Build Coastguard Worker "scale_grad_by_freq": False, 152*da0073e9SAndroid Build Coastguard Worker "sparse": True, 153*da0073e9SAndroid Build Coastguard Worker } 154*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding.from_pretrained(a, **opts) 155*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor([0, 1]) 156*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 157*da0073e9SAndroid Build Coastguard Worker # test output and that weight matrix was renormalized 158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, output) 159*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()) 160*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 161*da0073e9SAndroid Build Coastguard Worker output.data.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all() 162*da0073e9SAndroid Build Coastguard Worker ) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker def test_embedding_functional(self): 165*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long) 166*da0073e9SAndroid Build Coastguard Worker embeddings = torch.rand(4, 3, requires_grad=True) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker embed_old = torch.nn.Embedding(4, 3) 169*da0073e9SAndroid Build Coastguard Worker embed_old.weight.data = embeddings.data 170*da0073e9SAndroid Build Coastguard Worker # A silly test for eager, this test is useful for when we run under PYTORCH_TEST_WITH_DYNAMO=1 171*da0073e9SAndroid Build Coastguard Worker # as it ensures that setattr correctly works. 172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embed_old.weight.data, embeddings.data) 173*da0073e9SAndroid Build Coastguard Worker res_old = embed_old(a) 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker res_F = F.embedding(a, embeddings) 176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_old, res_F) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker embed_old = torch.nn.Embedding(4, 3) 179*da0073e9SAndroid Build Coastguard Worker embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) 180*da0073e9SAndroid Build Coastguard Worker res_old = embed_old(a) 181*da0073e9SAndroid Build Coastguard Worker res_F = F.embedding(a, embeddings, padding_idx=2) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_old, res_F) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/130806 186*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("40GB", device="cuda") 187*da0073e9SAndroid Build Coastguard Worker def test_large_tensors(self): 188*da0073e9SAndroid Build Coastguard Worker input = torch.randint(low=0, high=16032, size=[131072], device="cuda") 189*da0073e9SAndroid Build Coastguard Worker w = torch.randn([16032, 16384], device="cuda") 190*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.embedding(input, w) 191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dim(), 2) 192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.numel(), 2147483648) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_functional(self): 195*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long) 196*da0073e9SAndroid Build Coastguard Worker embeddings = torch.rand(4, 3, requires_grad=True) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker embed_old = torch.nn.EmbeddingBag(4, 3) 199*da0073e9SAndroid Build Coastguard Worker embed_old.weight = torch.nn.Parameter(embeddings) 200*da0073e9SAndroid Build Coastguard Worker res_old = embed_old(a) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker res_F = F.embedding_bag(a, embeddings) 203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_old, res_F) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker embed_old = torch.nn.EmbeddingBag(4, 3) 206*da0073e9SAndroid Build Coastguard Worker embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) 207*da0073e9SAndroid Build Coastguard Worker res_old = embed_old(a) 208*da0073e9SAndroid Build Coastguard Worker res_F = F.embedding_bag(a, embeddings, padding_idx=2) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_old, res_F) 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker # Make sure that error is thrown if padding_idx is out of bounds 213*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_padding_idx_error(self): 214*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=torch.long) 215*da0073e9SAndroid Build Coastguard Worker num_embeddings = 4 216*da0073e9SAndroid Build Coastguard Worker num_features = 3 217*da0073e9SAndroid Build Coastguard Worker embeddings = torch.rand(num_embeddings, num_features, requires_grad=True) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker functional_err_msg = r"padding_idx must be within the number of embeddings" 220*da0073e9SAndroid Build Coastguard Worker module_err_msg = r"padding_idx must be within num_embeddings" 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker for padding_idx in range(-(num_embeddings + 2), (num_embeddings + 2)): 223*da0073e9SAndroid Build Coastguard Worker if (padding_idx < -num_embeddings) or (padding_idx >= num_embeddings): 224*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, functional_err_msg): 225*da0073e9SAndroid Build Coastguard Worker F.embedding_bag(a, embeddings, padding_idx=padding_idx) 226*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, module_err_msg): 227*da0073e9SAndroid Build Coastguard Worker torch.nn.EmbeddingBag( 228*da0073e9SAndroid Build Coastguard Worker num_embeddings, num_features, padding_idx=padding_idx 229*da0073e9SAndroid Build Coastguard Worker ) 230*da0073e9SAndroid Build Coastguard Worker else: 231*da0073e9SAndroid Build Coastguard Worker F.embedding_bag(a, embeddings, padding_idx=padding_idx) 232*da0073e9SAndroid Build Coastguard Worker torch.nn.EmbeddingBag( 233*da0073e9SAndroid Build Coastguard Worker num_embeddings, num_features, padding_idx=padding_idx 234*da0073e9SAndroid Build Coastguard Worker ) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker def test_embeddingbag_from_pretrained(self): 237*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 238*da0073e9SAndroid Build Coastguard Worker embeddingbag = nn.EmbeddingBag.from_pretrained(a) 239*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, embeddingbag.weight.data) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor([[0, 1]]) 242*da0073e9SAndroid Build Coastguard Worker output = embeddingbag(input) 243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.mean(0, keepdim=True), output) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker def test_embeddingbag_from_pretrained_options(self): 246*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 247*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 248*da0073e9SAndroid Build Coastguard Worker opts = { 249*da0073e9SAndroid Build Coastguard Worker "max_norm": 2.0, 250*da0073e9SAndroid Build Coastguard Worker "norm_type": 0.5, 251*da0073e9SAndroid Build Coastguard Worker "scale_grad_by_freq": False, 252*da0073e9SAndroid Build Coastguard Worker "mode": "max", 253*da0073e9SAndroid Build Coastguard Worker "sparse": False, 254*da0073e9SAndroid Build Coastguard Worker } 255*da0073e9SAndroid Build Coastguard Worker embeddingbag = nn.EmbeddingBag.from_pretrained(a, **opts) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker input = torch.LongTensor([[0, 1]]) 258*da0073e9SAndroid Build Coastguard Worker output = embeddingbag(input) 259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.max(0, keepdim=True)[0], output) 260*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()) 261*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 262*da0073e9SAndroid Build Coastguard Worker a.norm(p=opts["norm_type"], dim=1).le(opts["max_norm"]).all() 263*da0073e9SAndroid Build Coastguard Worker ) 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker def test_embeddingbag_include_last_offset(self): 266*da0073e9SAndroid Build Coastguard Worker # Test case from https://github.com/pytorch/pytorch/issues/89677 267*da0073e9SAndroid Build Coastguard Worker embeddingbag = nn.EmbeddingBag(100, 3, include_last_offset=True, padding_idx=61) 268*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([0, 1, 2, 3]) 269*da0073e9SAndroid Build Coastguard Worker out = embeddingbag(input, torch.tensor([0, 3, 3])) 270*da0073e9SAndroid Build Coastguard Worker out2 = embeddingbag(input, torch.tensor([0, 3, 4])) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker weight = embeddingbag.weight 273*da0073e9SAndroid Build Coastguard Worker row0 = weight[0:3].mean(0) 274*da0073e9SAndroid Build Coastguard Worker row1 = weight[3] 275*da0073e9SAndroid Build Coastguard Worker ref_out = torch.stack([row0, row1]) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_out, out) 278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_out, out2) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Workerclass TestEmbeddingNNDeviceType(NNTestCase): 282*da0073e9SAndroid Build Coastguard Worker def test_embedding_dense_grad(self, device): 283*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 284*da0073e9SAndroid Build Coastguard Worker embd = nn.Embedding(20, 20).to(device) 285*da0073e9SAndroid Build Coastguard Worker weight = embd.weight 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker def fn_wrapper(device): 288*da0073e9SAndroid Build Coastguard Worker def fn(weight): 289*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor( 290*da0073e9SAndroid Build Coastguard Worker [[0, 1, 1, 2], [3, 5, 7, 11]], dtype=torch.long 291*da0073e9SAndroid Build Coastguard Worker ).to(device) 292*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.embedding(inp, weight) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker return fn 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker fn = fn_wrapper(device) 297*da0073e9SAndroid Build Coastguard Worker _assertGradAndGradgradChecks(self, fn, (weight,)) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker def test_embedding_scalar_weight_error(self, device): 300*da0073e9SAndroid Build Coastguard Worker indices = torch.rand(2, 2, device=device).long() 301*da0073e9SAndroid Build Coastguard Worker weights = [ 302*da0073e9SAndroid Build Coastguard Worker torch.tensor(1.0, device=device), 303*da0073e9SAndroid Build Coastguard Worker torch.tensor(1.0, device=device).reshape(1, 1, 1), 304*da0073e9SAndroid Build Coastguard Worker ] 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker for weight in weights: 307*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"): 308*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.embedding(indices, weight) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.float16, torch.float64) 311*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float64) 312*da0073e9SAndroid Build Coastguard Worker def test_embedding_backward(self, device, dtype): 313*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 3, sparse=True) 314*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([[7, 1, 3]]) 315*da0073e9SAndroid Build Coastguard Worker ones = torch.tensor(1.0, dtype=dtype).expand(3, 3) 316*da0073e9SAndroid Build Coastguard Worker tensorTwice = tensor.repeat(1, 2) 317*da0073e9SAndroid Build Coastguard Worker onesTwice = torch.cat((ones, ones)) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker embedding = embedding.to(dtype=dtype).to(device) 320*da0073e9SAndroid Build Coastguard Worker tensor = tensor.to(device) 321*da0073e9SAndroid Build Coastguard Worker ones = ones.to(device) 322*da0073e9SAndroid Build Coastguard Worker tensorTwice = tensorTwice.to(device) 323*da0073e9SAndroid Build Coastguard Worker onesTwice = onesTwice.to(device) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker embedding.zero_grad() 326*da0073e9SAndroid Build Coastguard Worker embedding(tensor[0]).sum().backward() 327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad._indices(), tensor) 328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad._values(), ones) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker embedding.zero_grad() 331*da0073e9SAndroid Build Coastguard Worker embedding(tensor[0]).sum().backward() 332*da0073e9SAndroid Build Coastguard Worker embedding(tensor[0]).sum().backward() 333*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad._indices(), tensorTwice) 334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad._values(), onesTwice) 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker embedding.zero_grad() 337*da0073e9SAndroid Build Coastguard Worker embedding(tensor[0]).sum().backward() 338*da0073e9SAndroid Build Coastguard Worker tensor[0, 0] = 8 339*da0073e9SAndroid Build Coastguard Worker embedding(tensor[0]).sum().backward() 340*da0073e9SAndroid Build Coastguard Worker tensorTwice[0, 3] = 8 341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad._indices(), tensorTwice) 342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(embedding.weight.grad._values(), onesTwice) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 345*da0073e9SAndroid Build Coastguard Worker *( 346*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.bfloat16, torch.half) 347*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM 348*da0073e9SAndroid Build Coastguard Worker else (torch.float, torch.double, torch.half) 349*da0073e9SAndroid Build Coastguard Worker ) 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32) 352*da0073e9SAndroid Build Coastguard Worker def test_embedding_max_norm_backward(self, device, dtype): 353*da0073e9SAndroid Build Coastguard Worker # can't use gradcheck since in place renorm makes analytical gradients different from produced ones 354*da0073e9SAndroid Build Coastguard Worker weight = torch.randn((4, 4), device=device, dtype=dtype) * 2 355*da0073e9SAndroid Build Coastguard Worker weight.requires_grad_() 356*da0073e9SAndroid Build Coastguard Worker inp_list = [0, 1, 2, 2] 357*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor(inp_list, device=device) 358*da0073e9SAndroid Build Coastguard Worker out = nn.functional.embedding(inp, weight, max_norm=1.0).sum() 359*da0073e9SAndroid Build Coastguard Worker out.backward() 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker expected_grad = ( 362*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1.0, 1.0, 2.0, 0.0]], device=device, dtype=dtype) 363*da0073e9SAndroid Build Coastguard Worker .transpose(0, 1) 364*da0073e9SAndroid Build Coastguard Worker .expand(4, 4) 365*da0073e9SAndroid Build Coastguard Worker ) 366*da0073e9SAndroid Build Coastguard Worker self.assertEqual(weight.grad, expected_grad) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 369*da0073e9SAndroid Build Coastguard Worker *( 370*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.bfloat16, torch.half) 371*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM 372*da0073e9SAndroid Build Coastguard Worker else (torch.float, torch.double, torch.half) 373*da0073e9SAndroid Build Coastguard Worker ) 374*da0073e9SAndroid Build Coastguard Worker ) 375*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32) 376*da0073e9SAndroid Build Coastguard Worker def test_embedding_max_norm_fwd_AD(self, device, dtype): 377*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == "xla": 378*da0073e9SAndroid Build Coastguard Worker self.skipTest("forward AD doesn't work on xla") 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker # can't use gradcheck since in place renorm makes analytical gradients different from produced ones 381*da0073e9SAndroid Build Coastguard Worker weight = torch.randn((4, 4), device=device, dtype=dtype) * 2 382*da0073e9SAndroid Build Coastguard Worker tangent = torch.ones((4, 4), device=device, dtype=dtype) 383*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([[0, 1], [2, 2]], device=device) 384*da0073e9SAndroid Build Coastguard Worker with torch.autograd.forward_ad.dual_level(): 385*da0073e9SAndroid Build Coastguard Worker dual_weight = torch.autograd.forward_ad.make_dual(weight, tangent) 386*da0073e9SAndroid Build Coastguard Worker out = nn.functional.embedding(inp, dual_weight, max_norm=1.0) 387*da0073e9SAndroid Build Coastguard Worker jvp = torch.autograd.forward_ad.unpack_dual(out).tangent 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker expected_grad = torch.ones((2, 2, 4), device=device, dtype=dtype) 390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jvp, expected_grad) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 393*da0073e9SAndroid Build Coastguard Worker *( 394*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.bfloat16, torch.half) 395*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM 396*da0073e9SAndroid Build Coastguard Worker else (torch.float, torch.double, torch.half) 397*da0073e9SAndroid Build Coastguard Worker ) 398*da0073e9SAndroid Build Coastguard Worker ) 399*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32) 400*da0073e9SAndroid Build Coastguard Worker def test_embedding_padding_idx(self, device, dtype): 401*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype) 402*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device) 403*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[0][0].sum(), 0) 405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1][2].sum(), 0) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True).to(device, dtype) 408*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device) 409*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[0][0].sum(), 0) 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1][2].sum(), 0) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker # negative indexing check for padding_idx 414*da0073e9SAndroid Build Coastguard Worker # padding_idx=-2, num_embeddings=10 ==> index 8 padded 415*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 20, padding_idx=-2).to(device, dtype) 416*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device) 417*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[0][2].sum(), 0) 419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1][1].sum(), 0) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True).to(device, dtype) 422*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device) 423*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 424*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[0][2].sum(), 0) 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1][1].sum(), 0) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker # change padding vector 428*da0073e9SAndroid Build Coastguard Worker padding_vector = torch.ones(20, dtype=dtype, device=device) 429*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(10, 20, padding_idx=2, sparse=True).to(device, dtype) 430*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 431*da0073e9SAndroid Build Coastguard Worker embedding.weight[2] = padding_vector 432*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([0, 2], dtype=torch.long).to(device) 433*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1], padding_vector) 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker # out of bounds check for padding_idx 437*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 438*da0073e9SAndroid Build Coastguard Worker AssertionError, 439*da0073e9SAndroid Build Coastguard Worker nn.Embedding, 440*da0073e9SAndroid Build Coastguard Worker num_embeddings=10, 441*da0073e9SAndroid Build Coastguard Worker embedding_dim=20, 442*da0073e9SAndroid Build Coastguard Worker padding_idx=25, 443*da0073e9SAndroid Build Coastguard Worker ) 444*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 445*da0073e9SAndroid Build Coastguard Worker AssertionError, 446*da0073e9SAndroid Build Coastguard Worker nn.Embedding, 447*da0073e9SAndroid Build Coastguard Worker num_embeddings=10, 448*da0073e9SAndroid Build Coastguard Worker embedding_dim=20, 449*da0073e9SAndroid Build Coastguard Worker padding_idx=-25, 450*da0073e9SAndroid Build Coastguard Worker ) 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker padding_idx = 0 453*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to(device, dtype) 454*da0073e9SAndroid Build Coastguard Worker for n in ( 455*da0073e9SAndroid Build Coastguard Worker 1, 456*da0073e9SAndroid Build Coastguard Worker 2, 457*da0073e9SAndroid Build Coastguard Worker 1000, 458*da0073e9SAndroid Build Coastguard Worker ): # Need large N to trigger all the methods we have implemented 459*da0073e9SAndroid Build Coastguard Worker for other_indices in ([], [1, 3], [2]): 460*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor( 461*da0073e9SAndroid Build Coastguard Worker other_indices + [padding_idx] * n, dtype=torch.long 462*da0073e9SAndroid Build Coastguard Worker ).to(device) 463*da0073e9SAndroid Build Coastguard Worker pre = embedding.weight[padding_idx].clone() 464*da0073e9SAndroid Build Coastguard Worker embedding(indices).sum().backward() 465*da0073e9SAndroid Build Coastguard Worker after = (embedding.weight + embedding.weight.grad)[padding_idx] 466*da0073e9SAndroid Build Coastguard Worker embedding.zero_grad() 467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(after, pre) 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker # test double backward 470*da0073e9SAndroid Build Coastguard Worker emb_sum = embedding(indices).sum() 471*da0073e9SAndroid Build Coastguard Worker emb_grad = torch.autograd.grad( 472*da0073e9SAndroid Build Coastguard Worker outputs=emb_sum, 473*da0073e9SAndroid Build Coastguard Worker inputs=list(embedding.parameters()), 474*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 475*da0073e9SAndroid Build Coastguard Worker ) 476*da0073e9SAndroid Build Coastguard Worker scalar = emb_grad[0].sum() + emb_sum 477*da0073e9SAndroid Build Coastguard Worker scalar.backward() 478*da0073e9SAndroid Build Coastguard Worker after = (embedding.weight + embedding.weight.grad)[padding_idx] 479*da0073e9SAndroid Build Coastguard Worker embedding.zero_grad() 480*da0073e9SAndroid Build Coastguard Worker self.assertEqual(after, pre) 481*da0073e9SAndroid Build Coastguard Worker 482*da0073e9SAndroid Build Coastguard Worker # Check correctness of torch.nn.functional.embedding_bag forward and 483*da0073e9SAndroid Build Coastguard Worker # backward functions with padding_idx, given a 1D input separated into bags 484*da0073e9SAndroid Build Coastguard Worker # with an offset array. Compare against an equivalent 2D input that uses 485*da0073e9SAndroid Build Coastguard Worker # padding indices to fill in the gaps indicated by the offset array 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("see https://github.com/pytorch/pytorch/pull/95621") 488*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 489*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 490*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.bfloat16) 491*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_1D_padding_idx(self, device, dtype): 492*da0073e9SAndroid Build Coastguard Worker num_features = 3 493*da0073e9SAndroid Build Coastguard Worker max_indices_per_bag = 10 494*da0073e9SAndroid Build Coastguard Worker num_bags = 10 495*da0073e9SAndroid Build Coastguard Worker num_words = 100 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker def gen_1D_indices_offsets(include_last_offset, allpad): 498*da0073e9SAndroid Build Coastguard Worker indices = [] 499*da0073e9SAndroid Build Coastguard Worker offsets = [] 500*da0073e9SAndroid Build Coastguard Worker cur_offset = 0 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker # Make one bag full and one bag empty, for extra coverage 503*da0073e9SAndroid Build Coastguard Worker empty_bag = random.randint(0, num_bags - 1) 504*da0073e9SAndroid Build Coastguard Worker full_bag = empty_bag 505*da0073e9SAndroid Build Coastguard Worker while full_bag == empty_bag: 506*da0073e9SAndroid Build Coastguard Worker full_bag = random.randint(0, num_bags - 1) 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker for bag in range(num_bags): 509*da0073e9SAndroid Build Coastguard Worker offsets.append(cur_offset) 510*da0073e9SAndroid Build Coastguard Worker if bag == full_bag: 511*da0073e9SAndroid Build Coastguard Worker bag_size = max_indices_per_bag 512*da0073e9SAndroid Build Coastguard Worker elif bag == empty_bag: 513*da0073e9SAndroid Build Coastguard Worker bag_size = 0 514*da0073e9SAndroid Build Coastguard Worker else: 515*da0073e9SAndroid Build Coastguard Worker bag_size = random.randint(1, max_indices_per_bag - 1) 516*da0073e9SAndroid Build Coastguard Worker indices += [ 517*da0073e9SAndroid Build Coastguard Worker 1 if allpad else random.randint(0, num_words - 1) 518*da0073e9SAndroid Build Coastguard Worker for _ in range(bag_size) 519*da0073e9SAndroid Build Coastguard Worker ] 520*da0073e9SAndroid Build Coastguard Worker cur_offset += bag_size 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker # embedding_bag requires first entry of offsets to be 0 523*da0073e9SAndroid Build Coastguard Worker assert offsets[0] == 0 524*da0073e9SAndroid Build Coastguard Worker 525*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(indices, device=device) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker if include_last_offset: 528*da0073e9SAndroid Build Coastguard Worker offsets.append(indices.size(0)) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor(offsets, device=device) 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker return indices, offsets 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker # Convert a 1-D indices-offsets representation into 2-D. Fill any empty 535*da0073e9SAndroid Build Coastguard Worker # indices with padding_idx 536*da0073e9SAndroid Build Coastguard Worker def gen_2D_indices_from_1D( 537*da0073e9SAndroid Build Coastguard Worker indices_1D, offsets, include_last_offset, padding_idx 538*da0073e9SAndroid Build Coastguard Worker ): 539*da0073e9SAndroid Build Coastguard Worker assert offsets[0] == 0 540*da0073e9SAndroid Build Coastguard Worker if include_last_offset: 541*da0073e9SAndroid Build Coastguard Worker offsets = offsets[:-1] 542*da0073e9SAndroid Build Coastguard Worker indices_2D = torch.empty( 543*da0073e9SAndroid Build Coastguard Worker num_bags, max_indices_per_bag, device=device, dtype=torch.long 544*da0073e9SAndroid Build Coastguard Worker ) 545*da0073e9SAndroid Build Coastguard Worker for bag in range(num_bags): 546*da0073e9SAndroid Build Coastguard Worker # Determine the start and end position of the bag within indices_1D 547*da0073e9SAndroid Build Coastguard Worker start = offsets[bag] 548*da0073e9SAndroid Build Coastguard Worker end = len(indices_1D) if bag + 1 == num_bags else offsets[bag + 1] 549*da0073e9SAndroid Build Coastguard Worker end = min(len(indices_1D), end) 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker # Pull out the bag's indices from indices_1D, and fill any 552*da0073e9SAndroid Build Coastguard Worker # remaining space with padding indices 553*da0073e9SAndroid Build Coastguard Worker indices_in_bag = [] 554*da0073e9SAndroid Build Coastguard Worker for item_pos in range(0, max_indices_per_bag): 555*da0073e9SAndroid Build Coastguard Worker if (start + item_pos) < end: 556*da0073e9SAndroid Build Coastguard Worker indices_in_bag.append(indices_1D[start + item_pos]) 557*da0073e9SAndroid Build Coastguard Worker else: 558*da0073e9SAndroid Build Coastguard Worker indices_in_bag.append(padding_idx) 559*da0073e9SAndroid Build Coastguard Worker indices_2D[bag] = torch.tensor(indices_in_bag, device=device) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker return indices_2D 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker test_cases = product( 564*da0073e9SAndroid Build Coastguard Worker ["max", "mean", "sum"], [False, True], [False, True], [False, True] 565*da0073e9SAndroid Build Coastguard Worker ) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker for mode, sparse, include_last_offset, allpad in test_cases: 568*da0073e9SAndroid Build Coastguard Worker # Max sparse and bfloat16 are not supported 569*da0073e9SAndroid Build Coastguard Worker if mode == "max": 570*da0073e9SAndroid Build Coastguard Worker if sparse or (dtype == torch.bfloat16): 571*da0073e9SAndroid Build Coastguard Worker continue 572*da0073e9SAndroid Build Coastguard Worker indices_1D, offsets = gen_1D_indices_offsets(include_last_offset, allpad) 573*da0073e9SAndroid Build Coastguard Worker for padding_idx_1D in list(set(indices_1D.tolist())) + [None]: 574*da0073e9SAndroid Build Coastguard Worker msg = ( 575*da0073e9SAndroid Build Coastguard Worker f"mode: '{mode}', sparse: {sparse}, include_last_offset: {include_last_offset}, " 576*da0073e9SAndroid Build Coastguard Worker f"padding_idx_1D: {padding_idx_1D}" 577*da0073e9SAndroid Build Coastguard Worker ) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker # If 1D input does not use a padding index, we still need one for the 2D input, 580*da0073e9SAndroid Build Coastguard Worker # so we can add one dummy word to the weights to act as the padded word 581*da0073e9SAndroid Build Coastguard Worker padding_idx_2D = ( 582*da0073e9SAndroid Build Coastguard Worker padding_idx_1D if padding_idx_1D is not None else num_words 583*da0073e9SAndroid Build Coastguard Worker ) 584*da0073e9SAndroid Build Coastguard Worker num_words_with_padding = ( 585*da0073e9SAndroid Build Coastguard Worker num_words if padding_idx_1D is not None else num_words + 1 586*da0073e9SAndroid Build Coastguard Worker ) 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker indices_2D = gen_2D_indices_from_1D( 589*da0073e9SAndroid Build Coastguard Worker indices_1D, offsets, include_last_offset, padding_idx_2D 590*da0073e9SAndroid Build Coastguard Worker ) 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker weights = torch.randn( 593*da0073e9SAndroid Build Coastguard Worker num_words_with_padding, 594*da0073e9SAndroid Build Coastguard Worker num_features, 595*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 596*da0073e9SAndroid Build Coastguard Worker device=device, 597*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 598*da0073e9SAndroid Build Coastguard Worker ) 599*da0073e9SAndroid Build Coastguard Worker weights_check = weights.clone().detach().requires_grad_(True) 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker bag = torch.nn.functional.embedding_bag( 602*da0073e9SAndroid Build Coastguard Worker indices_1D, 603*da0073e9SAndroid Build Coastguard Worker weights, 604*da0073e9SAndroid Build Coastguard Worker offsets, 605*da0073e9SAndroid Build Coastguard Worker padding_idx=padding_idx_1D, 606*da0073e9SAndroid Build Coastguard Worker mode=mode, 607*da0073e9SAndroid Build Coastguard Worker sparse=sparse, 608*da0073e9SAndroid Build Coastguard Worker include_last_offset=include_last_offset, 609*da0073e9SAndroid Build Coastguard Worker ) 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker bag_check = torch.nn.functional.embedding_bag( 612*da0073e9SAndroid Build Coastguard Worker indices_2D, 613*da0073e9SAndroid Build Coastguard Worker weights_check, 614*da0073e9SAndroid Build Coastguard Worker padding_idx=padding_idx_2D, 615*da0073e9SAndroid Build Coastguard Worker mode=mode, 616*da0073e9SAndroid Build Coastguard Worker sparse=sparse, 617*da0073e9SAndroid Build Coastguard Worker ) 618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bag, bag_check, msg=msg) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker bag.sum().backward() 621*da0073e9SAndroid Build Coastguard Worker bag_check.sum().backward() 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker # Sometimes, half dtype gradients mismatch by a greater amount 624*da0073e9SAndroid Build Coastguard Worker # than other dtypes 625*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.half, torch.bfloat16]: 626*da0073e9SAndroid Build Coastguard Worker atol = 0.01 627*da0073e9SAndroid Build Coastguard Worker rtol = 0.01 628*da0073e9SAndroid Build Coastguard Worker else: 629*da0073e9SAndroid Build Coastguard Worker atol = None 630*da0073e9SAndroid Build Coastguard Worker rtol = None 631*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 632*da0073e9SAndroid Build Coastguard Worker weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol 633*da0073e9SAndroid Build Coastguard Worker ) 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker # Check correctness of torch.nn.functional.embedding_bag forward and 636*da0073e9SAndroid Build Coastguard Worker # backward functions with padding_idx, given a 2D indices input. Compare 637*da0073e9SAndroid Build Coastguard Worker # against torch.nn.functional.embedding followed by a reduction. 638*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 639*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 640*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.bfloat16) 641*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_2D_padding_idx(self, device, dtype): 642*da0073e9SAndroid Build Coastguard Worker # Use a Python implementation of embedding_bag with padding_idx support 643*da0073e9SAndroid Build Coastguard Worker # to check torch.nn.functional.embedding_bag correctness 644*da0073e9SAndroid Build Coastguard Worker def embedding_bag_check(indices, weights, mode, sparse, padding_idx): 645*da0073e9SAndroid Build Coastguard Worker assert padding_idx is not None 646*da0073e9SAndroid Build Coastguard Worker embedding = torch.nn.functional.embedding( 647*da0073e9SAndroid Build Coastguard Worker indices, weights, padding_idx=padding_idx, sparse=sparse 648*da0073e9SAndroid Build Coastguard Worker ) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker reduction_dim = indices.dim() - 1 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker if mode == "sum" or mode == "mean": 653*da0073e9SAndroid Build Coastguard Worker # We must avoid including elements at padding_idx in the 654*da0073e9SAndroid Build Coastguard Worker # sum/mean, so multiply those elements by 0, and multiply 655*da0073e9SAndroid Build Coastguard Worker # all other elements by 1 656*da0073e9SAndroid Build Coastguard Worker per_sample_weights = indices.ne(padding_idx).to(dtype).unsqueeze(-1) 657*da0073e9SAndroid Build Coastguard Worker res = embedding.mul(per_sample_weights).sum(dim=reduction_dim) 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker if mode == "mean": 660*da0073e9SAndroid Build Coastguard Worker weights_sum = per_sample_weights.sum(dim=reduction_dim) 661*da0073e9SAndroid Build Coastguard Worker res = res.div(weights_sum) 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker elif mode == "max": 664*da0073e9SAndroid Build Coastguard Worker # We must avoid allowing elements at padding_idx to be chosen 665*da0073e9SAndroid Build Coastguard Worker # as the max, so set those elements to negative infinity 666*da0073e9SAndroid Build Coastguard Worker res = embedding.masked_fill( 667*da0073e9SAndroid Build Coastguard Worker indices.unsqueeze(-1) == padding_idx, -float("inf") 668*da0073e9SAndroid Build Coastguard Worker ).amax(dim=reduction_dim) 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Worker else: 671*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"mode '{mode}' is not available") 672*da0073e9SAndroid Build Coastguard Worker 673*da0073e9SAndroid Build Coastguard Worker # If a row is all padding, set its corresponding result row to 0. 674*da0073e9SAndroid Build Coastguard Worker # This is needed because the above mean and max mode 675*da0073e9SAndroid Build Coastguard Worker # implementations set these elements to nan and -inf, respectively 676*da0073e9SAndroid Build Coastguard Worker if mode in ["mean", "max"]: 677*da0073e9SAndroid Build Coastguard Worker res = res.masked_fill( 678*da0073e9SAndroid Build Coastguard Worker indices.eq(padding_idx).all(dim=-1).unsqueeze(-1), 0 679*da0073e9SAndroid Build Coastguard Worker ) 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker return res 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker num_features = 3 684*da0073e9SAndroid Build Coastguard Worker num_words = 10 685*da0073e9SAndroid Build Coastguard Worker indices_dim1 = 10 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker for mode, sparse, allpad, indices_dim0 in product( 688*da0073e9SAndroid Build Coastguard Worker ["max", "mean", "sum"], [False, True], [False, True], [1, 10] 689*da0073e9SAndroid Build Coastguard Worker ): 690*da0073e9SAndroid Build Coastguard Worker # Max sparse and bfloat16 are not supported 691*da0073e9SAndroid Build Coastguard Worker if mode == "max": 692*da0073e9SAndroid Build Coastguard Worker if sparse or (dtype == torch.bfloat16): 693*da0073e9SAndroid Build Coastguard Worker continue 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker if allpad: 696*da0073e9SAndroid Build Coastguard Worker indices = torch.empty( 697*da0073e9SAndroid Build Coastguard Worker indices_dim0, indices_dim1, dtype=torch.long, device=device 698*da0073e9SAndroid Build Coastguard Worker ).fill_(1) 699*da0073e9SAndroid Build Coastguard Worker else: 700*da0073e9SAndroid Build Coastguard Worker indices = torch.randint( 701*da0073e9SAndroid Build Coastguard Worker 0, num_words, (indices_dim0, indices_dim1), device=device 702*da0073e9SAndroid Build Coastguard Worker ) 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker if indices_dim0 > 1: 705*da0073e9SAndroid Build Coastguard Worker # Fill one row with duplicate index so we can test with a fully 706*da0073e9SAndroid Build Coastguard Worker # padded row 707*da0073e9SAndroid Build Coastguard Worker duplicate_row = random.randint(0, indices_dim0 - 1) 708*da0073e9SAndroid Build Coastguard Worker indices[duplicate_row] = indices[duplicate_row][0] 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker for padding_idx in list(set(indices.flatten(0, -1).tolist())): 711*da0073e9SAndroid Build Coastguard Worker weights = torch.randn( 712*da0073e9SAndroid Build Coastguard Worker num_words, 713*da0073e9SAndroid Build Coastguard Worker num_features, 714*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 715*da0073e9SAndroid Build Coastguard Worker device=device, 716*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 717*da0073e9SAndroid Build Coastguard Worker ) 718*da0073e9SAndroid Build Coastguard Worker weights_check = weights.clone().detach().requires_grad_(True) 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker msg = ( 721*da0073e9SAndroid Build Coastguard Worker f"mode: '{mode}', sparse: {sparse}, padding_idx: {padding_idx}, " 722*da0073e9SAndroid Build Coastguard Worker f"allpad: {allpad}, indices.size(): {indices.size()}" 723*da0073e9SAndroid Build Coastguard Worker ) 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker # Check forward with a Python implementation of padding_idx embedding_bag 726*da0073e9SAndroid Build Coastguard Worker bag_check = embedding_bag_check( 727*da0073e9SAndroid Build Coastguard Worker indices, weights_check, mode, sparse, padding_idx 728*da0073e9SAndroid Build Coastguard Worker ) 729*da0073e9SAndroid Build Coastguard Worker bag = torch.nn.functional.embedding_bag( 730*da0073e9SAndroid Build Coastguard Worker indices, weights, padding_idx=padding_idx, mode=mode, sparse=sparse 731*da0073e9SAndroid Build Coastguard Worker ) 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bag, bag_check, msg=msg) 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker bag_check.sum().backward() 736*da0073e9SAndroid Build Coastguard Worker grad_check = weights_check.grad 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker bag.sum().backward() 739*da0073e9SAndroid Build Coastguard Worker grad = weights.grad 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker # Sometimes, half dtype gradients mismatch by a greater amount 742*da0073e9SAndroid Build Coastguard Worker # than other dtypes 743*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.half, torch.bfloat16]: 744*da0073e9SAndroid Build Coastguard Worker atol = 0.01 745*da0073e9SAndroid Build Coastguard Worker rtol = 0.01 746*da0073e9SAndroid Build Coastguard Worker else: 747*da0073e9SAndroid Build Coastguard Worker atol = None 748*da0073e9SAndroid Build Coastguard Worker rtol = None 749*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol) 750*da0073e9SAndroid Build Coastguard Worker 751*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 752*da0073e9SAndroid Build Coastguard Worker @dtypes( 753*da0073e9SAndroid Build Coastguard Worker *( 754*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.bfloat16, torch.half) 755*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM 756*da0073e9SAndroid Build Coastguard Worker else (torch.float, torch.double, torch.half) 757*da0073e9SAndroid Build Coastguard Worker ) 758*da0073e9SAndroid Build Coastguard Worker ) 759*da0073e9SAndroid Build Coastguard Worker def test_embedding_max_norm_device(self, device, dtype): 760*da0073e9SAndroid Build Coastguard Worker embedding = nn.Embedding(22, 5, max_norm=1.0).to(device, dtype=dtype) 761*da0073e9SAndroid Build Coastguard Worker # nn.Embedding only takes LongTensor as input 762*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([2, 8, 8, 6], device=device, dtype=torch.long) 763*da0073e9SAndroid Build Coastguard Worker output = embedding(input) 764*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output[1], output[2]) 765*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 768*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_empty_input(self, device, dtypes): 769*da0073e9SAndroid Build Coastguard Worker m = 4 770*da0073e9SAndroid Build Coastguard Worker n = 3 771*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([], device=device, dtype=dtypes[0]) 772*da0073e9SAndroid Build Coastguard Worker for sparse in [True, False]: 773*da0073e9SAndroid Build Coastguard Worker Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) 774*da0073e9SAndroid Build Coastguard Worker Embed.to(device) 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker output = Embed( 777*da0073e9SAndroid Build Coastguard Worker input=x, offsets=torch.tensor([0], device=device, dtype=dtypes[1]) 778*da0073e9SAndroid Build Coastguard Worker ) 779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, torch.zeros_like(output)) 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker output = Embed( 782*da0073e9SAndroid Build Coastguard Worker input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtypes[1]) 783*da0073e9SAndroid Build Coastguard Worker ) 784*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, torch.zeros_like(output)) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.") 787*da0073e9SAndroid Build Coastguard Worker @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long))) 788*da0073e9SAndroid Build Coastguard Worker @parametrize_test("padding_idx", [None, 0]) 789*da0073e9SAndroid Build Coastguard Worker @parametrize_test("mode", ["sum", "mean", "max"]) 790*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_out_of_bounds_idx(self, device, dtypes, padding_idx, mode): 791*da0073e9SAndroid Build Coastguard Worker padding_idx = 0 792*da0073e9SAndroid Build Coastguard Worker w_dtype, idx_dtype = dtypes 793*da0073e9SAndroid Build Coastguard Worker # negative out-of-bound 794*da0073e9SAndroid Build Coastguard Worker idx1 = torch.tensor([[-1, 1]], device=device, dtype=idx_dtype) 795*da0073e9SAndroid Build Coastguard Worker # positive out-of-bound 796*da0073e9SAndroid Build Coastguard Worker idx2 = torch.tensor([[11, 8]], device=device, dtype=idx_dtype) 797*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(10, 2, device=device, dtype=w_dtype) 798*da0073e9SAndroid Build Coastguard Worker if mode == "sum": 799*da0073e9SAndroid Build Coastguard Worker # Only `sum` supports per_sample_weight 800*da0073e9SAndroid Build Coastguard Worker per_sample_weights = ( 801*da0073e9SAndroid Build Coastguard Worker None, 802*da0073e9SAndroid Build Coastguard Worker torch.randn_like(idx1, device=device, dtype=w_dtype), 803*da0073e9SAndroid Build Coastguard Worker ) 804*da0073e9SAndroid Build Coastguard Worker else: 805*da0073e9SAndroid Build Coastguard Worker per_sample_weights = (None,) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker for p_s_weights, idx in itertools.product(per_sample_weights, (idx1, idx2)): 808*da0073e9SAndroid Build Coastguard Worker msg = "Expected idx >= 0 && idx < num_embeddings" 809*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 810*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.embedding_bag( 811*da0073e9SAndroid Build Coastguard Worker idx, 812*da0073e9SAndroid Build Coastguard Worker weight, 813*da0073e9SAndroid Build Coastguard Worker per_sample_weights=p_s_weights, 814*da0073e9SAndroid Build Coastguard Worker padding_idx=padding_idx, 815*da0073e9SAndroid Build Coastguard Worker mode=mode, 816*da0073e9SAndroid Build Coastguard Worker ) 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_dimension_errors(self, device): 819*da0073e9SAndroid Build Coastguard Worker funcs = ( 820*da0073e9SAndroid Build Coastguard Worker lambda x, y, z: torch.nn.functional.embedding_bag(y, x, z), 821*da0073e9SAndroid Build Coastguard Worker torch.embedding_bag, 822*da0073e9SAndroid Build Coastguard Worker torch._embedding_bag, 823*da0073e9SAndroid Build Coastguard Worker torch._embedding_bag_forward_only, 824*da0073e9SAndroid Build Coastguard Worker ) 825*da0073e9SAndroid Build Coastguard Worker for i, f in enumerate(funcs): 826*da0073e9SAndroid Build Coastguard Worker err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker weight = torch.full( 829*da0073e9SAndroid Build Coastguard Worker ( 830*da0073e9SAndroid Build Coastguard Worker 2, 831*da0073e9SAndroid Build Coastguard Worker 6, 832*da0073e9SAndroid Build Coastguard Worker ), 833*da0073e9SAndroid Build Coastguard Worker 0, 834*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, 835*da0073e9SAndroid Build Coastguard Worker device=device, 836*da0073e9SAndroid Build Coastguard Worker ) 837*da0073e9SAndroid Build Coastguard Worker indices = torch.full( 838*da0073e9SAndroid Build Coastguard Worker ( 839*da0073e9SAndroid Build Coastguard Worker 2, 840*da0073e9SAndroid Build Coastguard Worker 0, 841*da0073e9SAndroid Build Coastguard Worker 0, 842*da0073e9SAndroid Build Coastguard Worker 6, 843*da0073e9SAndroid Build Coastguard Worker 6, 844*da0073e9SAndroid Build Coastguard Worker ), 845*da0073e9SAndroid Build Coastguard Worker 2, 846*da0073e9SAndroid Build Coastguard Worker dtype=torch.int64, 847*da0073e9SAndroid Build Coastguard Worker device=device, 848*da0073e9SAndroid Build Coastguard Worker ) 849*da0073e9SAndroid Build Coastguard Worker offsets = torch.full((2, 0, 0, 6, 6), 0, dtype=torch.int64, device=device) 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker if i == 0: 852*da0073e9SAndroid Build Coastguard Worker error_msg = "input has to be 1D or 2D Tensor" 853*da0073e9SAndroid Build Coastguard Worker else: 854*da0073e9SAndroid Build Coastguard Worker error_msg = "input has to be a 1D or 2D Tensor" 855*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaisesRegex)( 856*da0073e9SAndroid Build Coastguard Worker err_type, error_msg, lambda: f(weight, indices, offsets) 857*da0073e9SAndroid Build Coastguard Worker ) 858*da0073e9SAndroid Build Coastguard Worker 859*da0073e9SAndroid Build Coastguard Worker weight = torch.full((2, 2), 0, dtype=torch.float64, device=device) 860*da0073e9SAndroid Build Coastguard Worker indices = torch.full((2,), 1, dtype=torch.int64, device=device) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaisesRegex)( 863*da0073e9SAndroid Build Coastguard Worker err_type, 864*da0073e9SAndroid Build Coastguard Worker "offsets has to be a 1D Tensor", 865*da0073e9SAndroid Build Coastguard Worker lambda: f(weight, indices, offsets), 866*da0073e9SAndroid Build Coastguard Worker ) 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker weight = torch.full((2, 2, 2), 0, dtype=torch.float64, device=device) 869*da0073e9SAndroid Build Coastguard Worker indices = torch.full((2,), 2, dtype=torch.int64, device=device) 870*da0073e9SAndroid Build Coastguard Worker offsets = torch.full((2,), 0, dtype=torch.int64, device=device) 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaisesRegex)( 873*da0073e9SAndroid Build Coastguard Worker err_type, 874*da0073e9SAndroid Build Coastguard Worker "weight has to be a 2D Tensor", 875*da0073e9SAndroid Build Coastguard Worker lambda: f(weight, indices, offsets), 876*da0073e9SAndroid Build Coastguard Worker ) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 879*da0073e9SAndroid Build Coastguard Worker def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes): 880*da0073e9SAndroid Build Coastguard Worker # Failure 1: mismatched embeddings / per_sample_weights dtype 881*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(5, 2, mode="sum").to(dtype=torch.float, device=device) 882*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device) 883*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device) 884*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device) 885*da0073e9SAndroid Build Coastguard Worker if device == "cpu": 886*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "have the same type as"): 887*da0073e9SAndroid Build Coastguard Worker es(input, offsets, per_sample_weights) 888*da0073e9SAndroid Build Coastguard Worker else: 889*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "expected scalar type"): 890*da0073e9SAndroid Build Coastguard Worker es(input, offsets, per_sample_weights) 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker # Failure 2.1: input/per_sample_weights have different sizes (1d input) 893*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device) 894*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device) 895*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn(5, dtype=torch.float, device=device) 896*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "same shape as the input"): 897*da0073e9SAndroid Build Coastguard Worker es(input, offsets, per_sample_weights) 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker # Failure 2.2: input/per_sample_weights have different sizes (2d input) 900*da0073e9SAndroid Build Coastguard Worker input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device) 901*da0073e9SAndroid Build Coastguard Worker offsets = None 902*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device) 903*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "same shape as the input"): 904*da0073e9SAndroid Build Coastguard Worker es(input, offsets, per_sample_weights) 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker # Failure 3: Unsupported per_sample_weights and mode=('max', 'mean') 907*da0073e9SAndroid Build Coastguard Worker for unsupported_mode in ("max", "mean"): 908*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to( 909*da0073e9SAndroid Build Coastguard Worker dtype=torch.float, device=device 910*da0073e9SAndroid Build Coastguard Worker ) 911*da0073e9SAndroid Build Coastguard Worker input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device) 912*da0073e9SAndroid Build Coastguard Worker offsets = None 913*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device) 914*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 915*da0073e9SAndroid Build Coastguard Worker NotImplementedError, "only supported for mode='sum'" 916*da0073e9SAndroid Build Coastguard Worker ): 917*da0073e9SAndroid Build Coastguard Worker es(input, offsets, per_sample_weights) 918*da0073e9SAndroid Build Coastguard Worker 919*da0073e9SAndroid Build Coastguard Worker def _embedding_bag_reference_impl( 920*da0073e9SAndroid Build Coastguard Worker self, 921*da0073e9SAndroid Build Coastguard Worker input, 922*da0073e9SAndroid Build Coastguard Worker weight, 923*da0073e9SAndroid Build Coastguard Worker offsets=None, 924*da0073e9SAndroid Build Coastguard Worker mode="sum", 925*da0073e9SAndroid Build Coastguard Worker per_sample_weights=None, 926*da0073e9SAndroid Build Coastguard Worker include_last_offset=False, 927*da0073e9SAndroid Build Coastguard Worker ): 928*da0073e9SAndroid Build Coastguard Worker assert mode == "sum" or per_sample_weights is None 929*da0073e9SAndroid Build Coastguard Worker assert offsets is not None 930*da0073e9SAndroid Build Coastguard Worker if per_sample_weights is None: 931*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.ones(input.size()).to( 932*da0073e9SAndroid Build Coastguard Worker dtype=weight.dtype, device=weight.device 933*da0073e9SAndroid Build Coastguard Worker ) 934*da0073e9SAndroid Build Coastguard Worker assert input.numel() == per_sample_weights.numel() 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker bags = [] 937*da0073e9SAndroid Build Coastguard Worker long_input = input.to(torch.long) 938*da0073e9SAndroid Build Coastguard Worker embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze( 939*da0073e9SAndroid Build Coastguard Worker 1 940*da0073e9SAndroid Build Coastguard Worker ) 941*da0073e9SAndroid Build Coastguard Worker if include_last_offset: 942*da0073e9SAndroid Build Coastguard Worker for index in range(len(offsets) - 1): 943*da0073e9SAndroid Build Coastguard Worker offset = offsets[index] 944*da0073e9SAndroid Build Coastguard Worker next_offset = offsets[index + 1] 945*da0073e9SAndroid Build Coastguard Worker length = next_offset - offset 946*da0073e9SAndroid Build Coastguard Worker if length == 0: 947*da0073e9SAndroid Build Coastguard Worker bags.append( 948*da0073e9SAndroid Build Coastguard Worker torch.tensor([0] * weight.size(1)).to( 949*da0073e9SAndroid Build Coastguard Worker dtype=embeddings.dtype, device=embeddings.device 950*da0073e9SAndroid Build Coastguard Worker ) 951*da0073e9SAndroid Build Coastguard Worker ) 952*da0073e9SAndroid Build Coastguard Worker else: 953*da0073e9SAndroid Build Coastguard Worker if mode == "sum": 954*da0073e9SAndroid Build Coastguard Worker bags.append(embeddings.narrow(0, offset, length).sum(0)) 955*da0073e9SAndroid Build Coastguard Worker elif mode == "mean": 956*da0073e9SAndroid Build Coastguard Worker bags.append( 957*da0073e9SAndroid Build Coastguard Worker embeddings.narrow(0, offset, length).sum(0).div(length) 958*da0073e9SAndroid Build Coastguard Worker ) 959*da0073e9SAndroid Build Coastguard Worker else: 960*da0073e9SAndroid Build Coastguard Worker assert mode == "max" 961*da0073e9SAndroid Build Coastguard Worker bags.append(embeddings.narrow(0, offset, length).max(0)[0]) 962*da0073e9SAndroid Build Coastguard Worker else: 963*da0073e9SAndroid Build Coastguard Worker for index, offset in enumerate(offsets): 964*da0073e9SAndroid Build Coastguard Worker if index + 1 < len(offsets): 965*da0073e9SAndroid Build Coastguard Worker next_offset = offsets[index + 1] 966*da0073e9SAndroid Build Coastguard Worker else: 967*da0073e9SAndroid Build Coastguard Worker next_offset = len(long_input) 968*da0073e9SAndroid Build Coastguard Worker length = next_offset - offset 969*da0073e9SAndroid Build Coastguard Worker if length == 0: 970*da0073e9SAndroid Build Coastguard Worker bags.append( 971*da0073e9SAndroid Build Coastguard Worker torch.tensor([0] * weight.size(1)).to( 972*da0073e9SAndroid Build Coastguard Worker dtype=embeddings.dtype, device=embeddings.device 973*da0073e9SAndroid Build Coastguard Worker ) 974*da0073e9SAndroid Build Coastguard Worker ) 975*da0073e9SAndroid Build Coastguard Worker else: 976*da0073e9SAndroid Build Coastguard Worker if mode == "sum": 977*da0073e9SAndroid Build Coastguard Worker bags.append(embeddings.narrow(0, offset, length).sum(0)) 978*da0073e9SAndroid Build Coastguard Worker elif mode == "mean": 979*da0073e9SAndroid Build Coastguard Worker bags.append( 980*da0073e9SAndroid Build Coastguard Worker embeddings.narrow(0, offset, length).sum(0).div(length) 981*da0073e9SAndroid Build Coastguard Worker ) 982*da0073e9SAndroid Build Coastguard Worker else: 983*da0073e9SAndroid Build Coastguard Worker assert mode == "max" 984*da0073e9SAndroid Build Coastguard Worker bags.append(embeddings.narrow(0, offset, length).max(0)[0]) 985*da0073e9SAndroid Build Coastguard Worker return torch.stack(bags) 986*da0073e9SAndroid Build Coastguard Worker 987*da0073e9SAndroid Build Coastguard Worker @skipMeta 988*da0073e9SAndroid Build Coastguard Worker @dtypes( 989*da0073e9SAndroid Build Coastguard Worker *itertools.product( 990*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 991*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 992*da0073e9SAndroid Build Coastguard Worker (torch.half, torch.bfloat16, torch.float, torch.double), 993*da0073e9SAndroid Build Coastguard Worker ) 994*da0073e9SAndroid Build Coastguard Worker ) 995*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 996*da0073e9SAndroid Build Coastguard Worker *itertools.product( 997*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 998*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 999*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half), 1000*da0073e9SAndroid Build Coastguard Worker ) 1001*da0073e9SAndroid Build Coastguard Worker ) 1002*da0073e9SAndroid Build Coastguard Worker def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes): 1003*da0073e9SAndroid Build Coastguard Worker # Test empty input and per sample weight, and backward pass. There was a CUDA 1004*da0073e9SAndroid Build Coastguard Worker # invalid configuration bug (more context in #46572) 1005*da0073e9SAndroid Build Coastguard Worker def test_per_sample_weights(mode, trainable_scale): 1006*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device) 1007*da0073e9SAndroid Build Coastguard Worker es.weight.data.copy_( 1008*da0073e9SAndroid Build Coastguard Worker torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2]) 1009*da0073e9SAndroid Build Coastguard Worker ) 1010*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([], device=device, dtype=dtypes[0]) 1011*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[1]) 1012*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn_like( 1013*da0073e9SAndroid Build Coastguard Worker input, dtype=dtypes[2] 1014*da0073e9SAndroid Build Coastguard Worker ).requires_grad_(trainable_scale) 1015*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights = per_sample_weights.detach().requires_grad_( 1016*da0073e9SAndroid Build Coastguard Worker trainable_scale 1017*da0073e9SAndroid Build Coastguard Worker ) 1018*da0073e9SAndroid Build Coastguard Worker reference_weights = es.weight.detach().requires_grad_() 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker expected = self._embedding_bag_reference_impl( 1021*da0073e9SAndroid Build Coastguard Worker input, reference_weights, offsets, mode, ref_per_sample_weights 1022*da0073e9SAndroid Build Coastguard Worker ) 1023*da0073e9SAndroid Build Coastguard Worker result = es(input, offsets, per_sample_weights) 1024*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1025*da0073e9SAndroid Build Coastguard Worker result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0 1026*da0073e9SAndroid Build Coastguard Worker ) 1027*da0073e9SAndroid Build Coastguard Worker 1028*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(expected) 1029*da0073e9SAndroid Build Coastguard Worker result.backward(grad) 1030*da0073e9SAndroid Build Coastguard Worker # the reference impl doesn't have grad fn for empty input; but the grad should 1031*da0073e9SAndroid Build Coastguard Worker # simply be a zero tensor 1032*da0073e9SAndroid Build Coastguard Worker ref_weights_grad = torch.zeros_like(es.weight) 1033*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1034*da0073e9SAndroid Build Coastguard Worker es.weight.grad, 1035*da0073e9SAndroid Build Coastguard Worker ref_weights_grad, 1036*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[dtypes[2]], 1037*da0073e9SAndroid Build Coastguard Worker rtol=0, 1038*da0073e9SAndroid Build Coastguard Worker ) 1039*da0073e9SAndroid Build Coastguard Worker if trainable_scale: 1040*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights_grad = torch.empty_like(per_sample_weights) 1041*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1042*da0073e9SAndroid Build Coastguard Worker per_sample_weights.grad, 1043*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights_grad, 1044*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[dtypes[2]], 1045*da0073e9SAndroid Build Coastguard Worker rtol=0, 1046*da0073e9SAndroid Build Coastguard Worker ) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker modes = ("sum",) 1049*da0073e9SAndroid Build Coastguard Worker trainable_scale = (True, False) 1050*da0073e9SAndroid Build Coastguard Worker for mode, trainable in itertools.product(modes, trainable_scale): 1051*da0073e9SAndroid Build Coastguard Worker test_per_sample_weights(mode, trainable) 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker @skipMeta 1054*da0073e9SAndroid Build Coastguard Worker @dtypes( 1055*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1056*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1057*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1058*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half, torch.bfloat16), 1059*da0073e9SAndroid Build Coastguard Worker ) 1060*da0073e9SAndroid Build Coastguard Worker ) 1061*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 1062*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1063*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1064*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1065*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half), 1066*da0073e9SAndroid Build Coastguard Worker ) 1067*da0073e9SAndroid Build Coastguard Worker ) 1068*da0073e9SAndroid Build Coastguard Worker def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes): 1069*da0073e9SAndroid Build Coastguard Worker def test_per_sample_weights(mode, trainable_scale): 1070*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device) 1071*da0073e9SAndroid Build Coastguard Worker es.weight.data.copy_( 1072*da0073e9SAndroid Build Coastguard Worker torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2]) 1073*da0073e9SAndroid Build Coastguard Worker ) 1074*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) 1075*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1]) 1076*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn_like( 1077*da0073e9SAndroid Build Coastguard Worker input, dtype=dtypes[2] 1078*da0073e9SAndroid Build Coastguard Worker ).requires_grad_(trainable_scale) 1079*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights = per_sample_weights.detach().requires_grad_( 1080*da0073e9SAndroid Build Coastguard Worker trainable_scale 1081*da0073e9SAndroid Build Coastguard Worker ) 1082*da0073e9SAndroid Build Coastguard Worker reference_weights = es.weight.detach().requires_grad_() 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker expected = self._embedding_bag_reference_impl( 1085*da0073e9SAndroid Build Coastguard Worker input, reference_weights, offsets, mode, ref_per_sample_weights 1086*da0073e9SAndroid Build Coastguard Worker ) 1087*da0073e9SAndroid Build Coastguard Worker result = es(input, offsets, per_sample_weights) 1088*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1089*da0073e9SAndroid Build Coastguard Worker result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0 1090*da0073e9SAndroid Build Coastguard Worker ) 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(expected).to(dtype=dtypes[2], device=device) 1093*da0073e9SAndroid Build Coastguard Worker result.backward(grad) 1094*da0073e9SAndroid Build Coastguard Worker expected.backward(grad) 1095*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1096*da0073e9SAndroid Build Coastguard Worker es.weight.grad, 1097*da0073e9SAndroid Build Coastguard Worker reference_weights.grad, 1098*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[dtypes[2]], 1099*da0073e9SAndroid Build Coastguard Worker rtol=0, 1100*da0073e9SAndroid Build Coastguard Worker ) 1101*da0073e9SAndroid Build Coastguard Worker if trainable_scale: 1102*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1103*da0073e9SAndroid Build Coastguard Worker per_sample_weights.grad, 1104*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights.grad, 1105*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[dtypes[2]], 1106*da0073e9SAndroid Build Coastguard Worker rtol=0, 1107*da0073e9SAndroid Build Coastguard Worker ) 1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard Worker modes = ("sum",) 1110*da0073e9SAndroid Build Coastguard Worker trainable_scale = (True, False) 1111*da0073e9SAndroid Build Coastguard Worker for mode, trainable in itertools.product(modes, trainable_scale): 1112*da0073e9SAndroid Build Coastguard Worker test_per_sample_weights(mode, trainable) 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Worker @skipMeta 1115*da0073e9SAndroid Build Coastguard Worker @dtypes( 1116*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1117*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1118*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1119*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half, torch.bfloat16), 1120*da0073e9SAndroid Build Coastguard Worker ) 1121*da0073e9SAndroid Build Coastguard Worker ) 1122*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 1123*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1124*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1125*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1126*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half), 1127*da0073e9SAndroid Build Coastguard Worker ) 1128*da0073e9SAndroid Build Coastguard Worker ) 1129*da0073e9SAndroid Build Coastguard Worker def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes): 1130*da0073e9SAndroid Build Coastguard Worker def test_per_sample_weights_new_offsets( 1131*da0073e9SAndroid Build Coastguard Worker mode, trainable_scale, include_last_offset, has_weight=True 1132*da0073e9SAndroid Build Coastguard Worker ): 1133*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag( 1134*da0073e9SAndroid Build Coastguard Worker 5, 2, mode=mode, include_last_offset=include_last_offset 1135*da0073e9SAndroid Build Coastguard Worker ).to(dtype=dtypes[2], device=device) 1136*da0073e9SAndroid Build Coastguard Worker es.weight.data.copy_( 1137*da0073e9SAndroid Build Coastguard Worker torch.arange(1, 11, device=device).view_as(es.weight).to(dtypes[2]) 1138*da0073e9SAndroid Build Coastguard Worker ) 1139*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0]) 1140*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1]) 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker if include_last_offset: 1143*da0073e9SAndroid Build Coastguard Worker offsets = torch.cat( 1144*da0073e9SAndroid Build Coastguard Worker ( 1145*da0073e9SAndroid Build Coastguard Worker offsets, 1146*da0073e9SAndroid Build Coastguard Worker torch.tensor([input.size(0)], device=device, dtype=dtypes[1]), 1147*da0073e9SAndroid Build Coastguard Worker ), 1148*da0073e9SAndroid Build Coastguard Worker 0, 1149*da0073e9SAndroid Build Coastguard Worker ) 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker if has_weight: 1152*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn_like( 1153*da0073e9SAndroid Build Coastguard Worker input, device=device, dtype=dtypes[2] 1154*da0073e9SAndroid Build Coastguard Worker ).requires_grad_(trainable_scale) 1155*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights = per_sample_weights.detach().requires_grad_( 1156*da0073e9SAndroid Build Coastguard Worker trainable_scale 1157*da0073e9SAndroid Build Coastguard Worker ) 1158*da0073e9SAndroid Build Coastguard Worker else: 1159*da0073e9SAndroid Build Coastguard Worker per_sample_weights = None 1160*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights = None 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker reference_weights = es.weight.detach().requires_grad_() 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Worker expected = self._embedding_bag_reference_impl( 1165*da0073e9SAndroid Build Coastguard Worker input, 1166*da0073e9SAndroid Build Coastguard Worker reference_weights, 1167*da0073e9SAndroid Build Coastguard Worker offsets, 1168*da0073e9SAndroid Build Coastguard Worker mode, 1169*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights, 1170*da0073e9SAndroid Build Coastguard Worker include_last_offset, 1171*da0073e9SAndroid Build Coastguard Worker ) 1172*da0073e9SAndroid Build Coastguard Worker result = es(input, offsets, per_sample_weights) 1173*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1174*da0073e9SAndroid Build Coastguard Worker result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0 1175*da0073e9SAndroid Build Coastguard Worker ) 1176*da0073e9SAndroid Build Coastguard Worker 1177*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(expected) 1178*da0073e9SAndroid Build Coastguard Worker result.backward(grad) 1179*da0073e9SAndroid Build Coastguard Worker expected.backward(grad) 1180*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1181*da0073e9SAndroid Build Coastguard Worker es.weight.grad, 1182*da0073e9SAndroid Build Coastguard Worker reference_weights.grad, 1183*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[dtypes[2]], 1184*da0073e9SAndroid Build Coastguard Worker rtol=0, 1185*da0073e9SAndroid Build Coastguard Worker ) 1186*da0073e9SAndroid Build Coastguard Worker if has_weight and trainable_scale: 1187*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1188*da0073e9SAndroid Build Coastguard Worker per_sample_weights.grad, 1189*da0073e9SAndroid Build Coastguard Worker ref_per_sample_weights.grad, 1190*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[dtypes[2]], 1191*da0073e9SAndroid Build Coastguard Worker rtol=0, 1192*da0073e9SAndroid Build Coastguard Worker ) 1193*da0073e9SAndroid Build Coastguard Worker 1194*da0073e9SAndroid Build Coastguard Worker trainable_scale = (True, False) 1195*da0073e9SAndroid Build Coastguard Worker include_last_offset_list = (True, False) 1196*da0073e9SAndroid Build Coastguard Worker modes = (("sum", False), ("sum", True), ("max", False), ("mean", False)) 1197*da0073e9SAndroid Build Coastguard Worker for (mode, has_weight), trainable, include_last_offset in itertools.product( 1198*da0073e9SAndroid Build Coastguard Worker modes, trainable_scale, include_last_offset_list 1199*da0073e9SAndroid Build Coastguard Worker ): 1200*da0073e9SAndroid Build Coastguard Worker test_per_sample_weights_new_offsets( 1201*da0073e9SAndroid Build Coastguard Worker mode, trainable, include_last_offset, has_weight 1202*da0073e9SAndroid Build Coastguard Worker ) 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker def _test_EmbeddingBag_vs_Embedding( 1205*da0073e9SAndroid Build Coastguard Worker self, 1206*da0073e9SAndroid Build Coastguard Worker N, 1207*da0073e9SAndroid Build Coastguard Worker D, 1208*da0073e9SAndroid Build Coastguard Worker B, 1209*da0073e9SAndroid Build Coastguard Worker L, 1210*da0073e9SAndroid Build Coastguard Worker max_norm=None, 1211*da0073e9SAndroid Build Coastguard Worker mode="mean", 1212*da0073e9SAndroid Build Coastguard Worker device="cpu", 1213*da0073e9SAndroid Build Coastguard Worker wdtype=torch.float, 1214*da0073e9SAndroid Build Coastguard Worker dtype=torch.long, 1215*da0073e9SAndroid Build Coastguard Worker test_per_sample_weights=False, 1216*da0073e9SAndroid Build Coastguard Worker trainable_per_sample_weights=False, 1217*da0073e9SAndroid Build Coastguard Worker sparse=False, 1218*da0073e9SAndroid Build Coastguard Worker test_backward=True, 1219*da0073e9SAndroid Build Coastguard Worker backward_prec=None, 1220*da0073e9SAndroid Build Coastguard Worker ): 1221*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to( 1222*da0073e9SAndroid Build Coastguard Worker device, wdtype 1223*da0073e9SAndroid Build Coastguard Worker ) 1224*da0073e9SAndroid Build Coastguard Worker e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype) 1225*da0073e9SAndroid Build Coastguard Worker e.weight.data.copy_(es.weight) 1226*da0073e9SAndroid Build Coastguard Worker input = torch.randint(N, (B, L), device=device, dtype=dtype) 1227*da0073e9SAndroid Build Coastguard Worker offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L) 1228*da0073e9SAndroid Build Coastguard Worker grad_output = torch.rand(B, D, device=device, dtype=wdtype) 1229*da0073e9SAndroid Build Coastguard Worker 1230*da0073e9SAndroid Build Coastguard Worker if test_per_sample_weights: 1231*da0073e9SAndroid Build Coastguard Worker # To prevent large gradients, weights should sum to 1 for each bag 1232*da0073e9SAndroid Build Coastguard Worker per_sample_weights = torch.randn(B, L, device=device, dtype=wdtype).softmax( 1233*da0073e9SAndroid Build Coastguard Worker dim=-1 1234*da0073e9SAndroid Build Coastguard Worker ) 1235*da0073e9SAndroid Build Coastguard Worker per_sample_weights_reference = per_sample_weights.clone().requires_grad_( 1236*da0073e9SAndroid Build Coastguard Worker trainable_per_sample_weights 1237*da0073e9SAndroid Build Coastguard Worker ) 1238*da0073e9SAndroid Build Coastguard Worker per_sample_weights.requires_grad_(trainable_per_sample_weights) 1239*da0073e9SAndroid Build Coastguard Worker output = es(input.view(-1), offsets, per_sample_weights.view(-1)) 1240*da0073e9SAndroid Build Coastguard Worker else: 1241*da0073e9SAndroid Build Coastguard Worker output = es(input.view(-1), offsets) 1242*da0073e9SAndroid Build Coastguard Worker per_sample_weights = None 1243*da0073e9SAndroid Build Coastguard Worker per_sample_weights_reference = None 1244*da0073e9SAndroid Build Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker if mode == "sum": 1246*da0073e9SAndroid Build Coastguard Worker if test_per_sample_weights: 1247*da0073e9SAndroid Build Coastguard Worker ref_output = ( 1248*da0073e9SAndroid Build Coastguard Worker e(input) * per_sample_weights_reference.unsqueeze(-1) 1249*da0073e9SAndroid Build Coastguard Worker ).sum(1) 1250*da0073e9SAndroid Build Coastguard Worker else: 1251*da0073e9SAndroid Build Coastguard Worker ref_output = e(input).sum(1) 1252*da0073e9SAndroid Build Coastguard Worker elif mode == "mean": 1253*da0073e9SAndroid Build Coastguard Worker assert not test_per_sample_weights 1254*da0073e9SAndroid Build Coastguard Worker ref_output = e(input).mean(1) 1255*da0073e9SAndroid Build Coastguard Worker elif mode == "max": 1256*da0073e9SAndroid Build Coastguard Worker assert not test_per_sample_weights 1257*da0073e9SAndroid Build Coastguard Worker ref_output = e(input).max(1)[0] 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0) 1260*da0073e9SAndroid Build Coastguard Worker 1261*da0073e9SAndroid Build Coastguard Worker if not test_backward: 1262*da0073e9SAndroid Build Coastguard Worker return 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker output.backward(grad_output) 1265*da0073e9SAndroid Build Coastguard Worker ref_output.backward(grad_output) 1266*da0073e9SAndroid Build Coastguard Worker es_weight_grad = es.weight.grad 1267*da0073e9SAndroid Build Coastguard Worker if sparse: 1268*da0073e9SAndroid Build Coastguard Worker es_weight_grad = es.weight.grad.to_dense() 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker # We have more floating point error here because we are dealing with larger numbers 1271*da0073e9SAndroid Build Coastguard Worker if backward_prec is None: 1272*da0073e9SAndroid Build Coastguard Worker needed_prec = dtype2prec_DONTUSE[wdtype] * 5 1273*da0073e9SAndroid Build Coastguard Worker rtol = 0.02 if wdtype == torch.half else 0 1274*da0073e9SAndroid Build Coastguard Worker else: 1275*da0073e9SAndroid Build Coastguard Worker needed_prec = backward_prec 1276*da0073e9SAndroid Build Coastguard Worker rtol = 0 1277*da0073e9SAndroid Build Coastguard Worker 1278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(es_weight_grad, e.weight.grad, atol=needed_prec, rtol=rtol) 1279*da0073e9SAndroid Build Coastguard Worker 1280*da0073e9SAndroid Build Coastguard Worker if test_per_sample_weights and trainable_per_sample_weights: 1281*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1282*da0073e9SAndroid Build Coastguard Worker per_sample_weights.grad, 1283*da0073e9SAndroid Build Coastguard Worker per_sample_weights_reference.grad, 1284*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[wdtype], 1285*da0073e9SAndroid Build Coastguard Worker rtol=0, 1286*da0073e9SAndroid Build Coastguard Worker ) 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 1289*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1290*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), (torch.half, torch.float, torch.double) 1291*da0073e9SAndroid Build Coastguard Worker ) 1292*da0073e9SAndroid Build Coastguard Worker ) 1293*da0073e9SAndroid Build Coastguard Worker @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double))) 1294*da0073e9SAndroid Build Coastguard Worker def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes): 1295*da0073e9SAndroid Build Coastguard Worker def run_tests(mode, sparse, trainable_per_sample_weights): 1296*da0073e9SAndroid Build Coastguard Worker kwargs = dict( 1297*da0073e9SAndroid Build Coastguard Worker test_per_sample_weights=True, 1298*da0073e9SAndroid Build Coastguard Worker device=device, 1299*da0073e9SAndroid Build Coastguard Worker mode=mode, 1300*da0073e9SAndroid Build Coastguard Worker wdtype=dtypes[1], 1301*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1302*da0073e9SAndroid Build Coastguard Worker sparse=sparse, 1303*da0073e9SAndroid Build Coastguard Worker trainable_per_sample_weights=trainable_per_sample_weights, 1304*da0073e9SAndroid Build Coastguard Worker ) 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker # Simple case 1307*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag_vs_Embedding(2, 3, 5, 7, **kwargs) 1308*da0073e9SAndroid Build Coastguard Worker 1309*da0073e9SAndroid Build Coastguard Worker # B * L > 1000 1310*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag_vs_Embedding(2, 5, 53, 23, **kwargs) 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker # Large num_embedding 1313*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag_vs_Embedding(101, 5, 3, 7, **kwargs) 1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Worker # Large embedding_dim 1316*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs) 1317*da0073e9SAndroid Build Coastguard Worker 1318*da0073e9SAndroid Build Coastguard Worker modes = ("sum",) 1319*da0073e9SAndroid Build Coastguard Worker sparsity = (True, False) 1320*da0073e9SAndroid Build Coastguard Worker trainable_scale = (True, False) 1321*da0073e9SAndroid Build Coastguard Worker for mode, sparse, trainable_per_sample_weights in itertools.product( 1322*da0073e9SAndroid Build Coastguard Worker modes, sparsity, trainable_scale 1323*da0073e9SAndroid Build Coastguard Worker ): 1324*da0073e9SAndroid Build Coastguard Worker run_tests(mode, sparse, trainable_per_sample_weights) 1325*da0073e9SAndroid Build Coastguard Worker 1326*da0073e9SAndroid Build Coastguard Worker # Test CUDA Dense on half precision 1327*da0073e9SAndroid Build Coastguard Worker if device == "cuda": 1328*da0073e9SAndroid Build Coastguard Worker modes = ("sum",) 1329*da0073e9SAndroid Build Coastguard Worker sparsity = (False,) 1330*da0073e9SAndroid Build Coastguard Worker trainable_scale = (True, False) 1331*da0073e9SAndroid Build Coastguard Worker for mode, sparse, trainable_per_sample_weights in itertools.product( 1332*da0073e9SAndroid Build Coastguard Worker modes, sparsity, trainable_scale 1333*da0073e9SAndroid Build Coastguard Worker ): 1334*da0073e9SAndroid Build Coastguard Worker run_tests(mode, sparse, trainable_per_sample_weights) 1335*da0073e9SAndroid Build Coastguard Worker 1336*da0073e9SAndroid Build Coastguard Worker def _test_EmbeddingBag( 1337*da0073e9SAndroid Build Coastguard Worker self, 1338*da0073e9SAndroid Build Coastguard Worker device, 1339*da0073e9SAndroid Build Coastguard Worker mode, 1340*da0073e9SAndroid Build Coastguard Worker sparse, 1341*da0073e9SAndroid Build Coastguard Worker wdtype=torch.double, 1342*da0073e9SAndroid Build Coastguard Worker dtype=torch.long, 1343*da0073e9SAndroid Build Coastguard Worker odtype=torch.long, 1344*da0073e9SAndroid Build Coastguard Worker test_backward=True, 1345*da0073e9SAndroid Build Coastguard Worker ): 1346*da0073e9SAndroid Build Coastguard Worker # check a known test example 1347*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype) 1348*da0073e9SAndroid Build Coastguard Worker es.weight.data.copy_( 1349*da0073e9SAndroid Build Coastguard Worker torch.arange(1, 11, device=device).view_as(es.weight).to(wdtype) 1350*da0073e9SAndroid Build Coastguard Worker ) 1351*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype) 1352*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=odtype) 1353*da0073e9SAndroid Build Coastguard Worker 1354*da0073e9SAndroid Build Coastguard Worker grad_output = torch.tensor([1, 2, 3, 4], device=device, dtype=wdtype).view(2, 2) 1355*da0073e9SAndroid Build Coastguard Worker grad_output_with_empty = torch.tensor( 1356*da0073e9SAndroid Build Coastguard Worker [99, 99, 1, 2, 99, 99, 3, 4, 99, 99], device=device, dtype=wdtype 1357*da0073e9SAndroid Build Coastguard Worker ).view(5, 2) 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker if mode == "sum" or mode == "mean": 1360*da0073e9SAndroid Build Coastguard Worker denominator = 1 if mode == "sum" else 3 1361*da0073e9SAndroid Build Coastguard Worker expected_output = ( 1362*da0073e9SAndroid Build Coastguard Worker torch.tensor([[13, 16], [13, 16]], device=device, dtype=wdtype) 1363*da0073e9SAndroid Build Coastguard Worker / denominator 1364*da0073e9SAndroid Build Coastguard Worker ) 1365*da0073e9SAndroid Build Coastguard Worker 1366*da0073e9SAndroid Build Coastguard Worker expected_output_with_empty = ( 1367*da0073e9SAndroid Build Coastguard Worker torch.tensor( 1368*da0073e9SAndroid Build Coastguard Worker [[0, 0], [13, 16], [0, 0], [13, 16], [0, 0]], 1369*da0073e9SAndroid Build Coastguard Worker device=device, 1370*da0073e9SAndroid Build Coastguard Worker dtype=wdtype, 1371*da0073e9SAndroid Build Coastguard Worker ) 1372*da0073e9SAndroid Build Coastguard Worker / denominator 1373*da0073e9SAndroid Build Coastguard Worker ) 1374*da0073e9SAndroid Build Coastguard Worker 1375*da0073e9SAndroid Build Coastguard Worker expected_grad_weight = ( 1376*da0073e9SAndroid Build Coastguard Worker torch.tensor( 1377*da0073e9SAndroid Build Coastguard Worker [[3, 4], [5, 8], [0, 0], [1, 2], [3, 4]], 1378*da0073e9SAndroid Build Coastguard Worker device=device, 1379*da0073e9SAndroid Build Coastguard Worker dtype=wdtype, 1380*da0073e9SAndroid Build Coastguard Worker ) 1381*da0073e9SAndroid Build Coastguard Worker / denominator 1382*da0073e9SAndroid Build Coastguard Worker ) 1383*da0073e9SAndroid Build Coastguard Worker elif mode == "max": 1384*da0073e9SAndroid Build Coastguard Worker expected_output = torch.tensor( 1385*da0073e9SAndroid Build Coastguard Worker [[7, 8], [9, 10]], device=device, dtype=wdtype 1386*da0073e9SAndroid Build Coastguard Worker ) 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Worker expected_output_with_empty = torch.tensor( 1389*da0073e9SAndroid Build Coastguard Worker [[0, 0], [7, 8], [0, 0], [9, 10], [0, 0]], device=device, dtype=wdtype 1390*da0073e9SAndroid Build Coastguard Worker ) 1391*da0073e9SAndroid Build Coastguard Worker 1392*da0073e9SAndroid Build Coastguard Worker expected_grad_weight = torch.tensor( 1393*da0073e9SAndroid Build Coastguard Worker [[0, 0], [0, 0], [0, 0], [1, 2], [3, 4]], device=device, dtype=wdtype 1394*da0073e9SAndroid Build Coastguard Worker ) 1395*da0073e9SAndroid Build Coastguard Worker output = es(input, offsets) 1396*da0073e9SAndroid Build Coastguard Worker output.backward(grad_output_with_empty) 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker es_weight_grad = es.weight.grad 1399*da0073e9SAndroid Build Coastguard Worker if sparse: 1400*da0073e9SAndroid Build Coastguard Worker es_weight_grad = es.weight.grad.to_dense() 1401*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, expected_output_with_empty) 1402*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1403*da0073e9SAndroid Build Coastguard Worker es_weight_grad, 1404*da0073e9SAndroid Build Coastguard Worker expected_grad_weight, 1405*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[wdtype], 1406*da0073e9SAndroid Build Coastguard Worker rtol=0, 1407*da0073e9SAndroid Build Coastguard Worker ) 1408*da0073e9SAndroid Build Coastguard Worker 1409*da0073e9SAndroid Build Coastguard Worker # check same example except as 2D (2 x 3) 1410*da0073e9SAndroid Build Coastguard Worker input = input.view(2, -1) 1411*da0073e9SAndroid Build Coastguard Worker es.zero_grad() 1412*da0073e9SAndroid Build Coastguard Worker output = es(input) 1413*da0073e9SAndroid Build Coastguard Worker output.backward(grad_output) 1414*da0073e9SAndroid Build Coastguard Worker 1415*da0073e9SAndroid Build Coastguard Worker es_weight_grad = es.weight.grad 1416*da0073e9SAndroid Build Coastguard Worker if sparse: 1417*da0073e9SAndroid Build Coastguard Worker es_weight_grad = es.weight.grad.to_dense() 1418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, expected_output) 1419*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1420*da0073e9SAndroid Build Coastguard Worker es_weight_grad, 1421*da0073e9SAndroid Build Coastguard Worker expected_grad_weight, 1422*da0073e9SAndroid Build Coastguard Worker atol=dtype2prec_DONTUSE[wdtype], 1423*da0073e9SAndroid Build Coastguard Worker rtol=0, 1424*da0073e9SAndroid Build Coastguard Worker ) 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard Worker # test all empty bags 1427*da0073e9SAndroid Build Coastguard Worker es.zero_grad() 1428*da0073e9SAndroid Build Coastguard Worker inputs = torch.tensor([], dtype=dtype, device=device) 1429*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 0, 0, 0], dtype=odtype, device=device) 1430*da0073e9SAndroid Build Coastguard Worker es(inputs, offsets).sum().backward() 1431*da0073e9SAndroid Build Coastguard Worker dense_grad = es.weight.grad 1432*da0073e9SAndroid Build Coastguard Worker if dense_grad.is_sparse: 1433*da0073e9SAndroid Build Coastguard Worker dense_grad = dense_grad.to_dense() 1434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dense_grad, torch.zeros_like(es.weight)) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length 1437*da0073e9SAndroid Build Coastguard Worker N, D, B, L = ( 1438*da0073e9SAndroid Build Coastguard Worker random.randint(1, 100), 1439*da0073e9SAndroid Build Coastguard Worker random.randint(1, 100), 1440*da0073e9SAndroid Build Coastguard Worker random.randint(1, 50), 1441*da0073e9SAndroid Build Coastguard Worker random.randint(1, 50), 1442*da0073e9SAndroid Build Coastguard Worker ) 1443*da0073e9SAndroid Build Coastguard Worker kwargs = dict( 1444*da0073e9SAndroid Build Coastguard Worker mode=mode, 1445*da0073e9SAndroid Build Coastguard Worker sparse=sparse, 1446*da0073e9SAndroid Build Coastguard Worker device=device, 1447*da0073e9SAndroid Build Coastguard Worker wdtype=wdtype, 1448*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1449*da0073e9SAndroid Build Coastguard Worker test_backward=test_backward, 1450*da0073e9SAndroid Build Coastguard Worker ) 1451*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs) 1452*da0073e9SAndroid Build Coastguard Worker for max_norm in (None, 3): 1453*da0073e9SAndroid Build Coastguard Worker for p in itertools.product([1, 2], repeat=4): 1454*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag_vs_Embedding(*p, max_norm=max_norm, **kwargs) 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker # check that giving illegal input combos raises error 1457*da0073e9SAndroid Build Coastguard Worker es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse) 1458*da0073e9SAndroid Build Coastguard Worker input = torch.ones(3, 4, dtype=dtype) 1459*da0073e9SAndroid Build Coastguard Worker offset = torch.arange(0, 3, dtype=odtype) 1460*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input, offset)) 1461*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaises)(ValueError, lambda: es(input.view(-1))) 1462*da0073e9SAndroid Build Coastguard Worker offset[0] = 1 1463*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cpu": 1464*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaises)( 1465*da0073e9SAndroid Build Coastguard Worker RuntimeError, lambda: es(input.view(-1), offset) 1466*da0073e9SAndroid Build Coastguard Worker ) 1467*da0073e9SAndroid Build Coastguard Worker offset[0] = 0 1468*da0073e9SAndroid Build Coastguard Worker offset[-1] = 100 1469*da0073e9SAndroid Build Coastguard Worker torch._dynamo.disable(self.assertRaises)( 1470*da0073e9SAndroid Build Coastguard Worker RuntimeError, lambda: es(input.view(-1), offset) 1471*da0073e9SAndroid Build Coastguard Worker ) 1472*da0073e9SAndroid Build Coastguard Worker 1473*da0073e9SAndroid Build Coastguard Worker @skipMeta 1474*da0073e9SAndroid Build Coastguard Worker @dtypes( 1475*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1476*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1477*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1478*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half, torch.bfloat16), 1479*da0073e9SAndroid Build Coastguard Worker ) 1480*da0073e9SAndroid Build Coastguard Worker ) 1481*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 1482*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1483*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1484*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1485*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half), 1486*da0073e9SAndroid Build Coastguard Worker ) 1487*da0073e9SAndroid Build Coastguard Worker ) 1488*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_device(self, device, dtypes): 1489*da0073e9SAndroid Build Coastguard Worker if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu": 1490*da0073e9SAndroid Build Coastguard Worker self.skipTest("bfloat16 not supported with Jetson cpu") 1491*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 1492*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1493*da0073e9SAndroid Build Coastguard Worker device, 1494*da0073e9SAndroid Build Coastguard Worker "sum", 1495*da0073e9SAndroid Build Coastguard Worker False, 1496*da0073e9SAndroid Build Coastguard Worker wdtype=dtypes[2], 1497*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1498*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1499*da0073e9SAndroid Build Coastguard Worker ) 1500*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1501*da0073e9SAndroid Build Coastguard Worker device, 1502*da0073e9SAndroid Build Coastguard Worker "mean", 1503*da0073e9SAndroid Build Coastguard Worker False, 1504*da0073e9SAndroid Build Coastguard Worker wdtype=dtypes[2], 1505*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1506*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1507*da0073e9SAndroid Build Coastguard Worker ) 1508*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1509*da0073e9SAndroid Build Coastguard Worker device, 1510*da0073e9SAndroid Build Coastguard Worker "max", 1511*da0073e9SAndroid Build Coastguard Worker False, 1512*da0073e9SAndroid Build Coastguard Worker wdtype=dtypes[2], 1513*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1514*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1515*da0073e9SAndroid Build Coastguard Worker ) 1516*da0073e9SAndroid Build Coastguard Worker 1517*da0073e9SAndroid Build Coastguard Worker test_backward = False 1518*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cuda": 1519*da0073e9SAndroid Build Coastguard Worker # see 'todo' in test_embedding_bag. 1520*da0073e9SAndroid Build Coastguard Worker test_backward = dtypes[2] is not torch.float16 1521*da0073e9SAndroid Build Coastguard Worker elif self.device_type == "cpu": 1522*da0073e9SAndroid Build Coastguard Worker # TODO: figure out why precision on sparse embeddings isn't the 1523*da0073e9SAndroid Build Coastguard Worker # same as for dense. 1524*da0073e9SAndroid Build Coastguard Worker test_backward = ( 1525*da0073e9SAndroid Build Coastguard Worker dtypes[2] is not torch.float and dtypes[2] is not torch.float16 1526*da0073e9SAndroid Build Coastguard Worker ) 1527*da0073e9SAndroid Build Coastguard Worker 1528*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1529*da0073e9SAndroid Build Coastguard Worker device, 1530*da0073e9SAndroid Build Coastguard Worker "sum", 1531*da0073e9SAndroid Build Coastguard Worker True, 1532*da0073e9SAndroid Build Coastguard Worker wdtype=dtypes[2], 1533*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1534*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1535*da0073e9SAndroid Build Coastguard Worker test_backward=test_backward, 1536*da0073e9SAndroid Build Coastguard Worker ) 1537*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1538*da0073e9SAndroid Build Coastguard Worker device, 1539*da0073e9SAndroid Build Coastguard Worker "mean", 1540*da0073e9SAndroid Build Coastguard Worker True, 1541*da0073e9SAndroid Build Coastguard Worker wdtype=dtypes[2], 1542*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1543*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1544*da0073e9SAndroid Build Coastguard Worker test_backward=test_backward, 1545*da0073e9SAndroid Build Coastguard Worker ) 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker @skipMeta 1548*da0073e9SAndroid Build Coastguard Worker @dtypes( 1549*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1550*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1551*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1552*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half, torch.bfloat16), 1553*da0073e9SAndroid Build Coastguard Worker ) 1554*da0073e9SAndroid Build Coastguard Worker ) 1555*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 1556*da0073e9SAndroid Build Coastguard Worker *itertools.product( 1557*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1558*da0073e9SAndroid Build Coastguard Worker (torch.int, torch.long), 1559*da0073e9SAndroid Build Coastguard Worker (torch.float, torch.double, torch.half), 1560*da0073e9SAndroid Build Coastguard Worker ) 1561*da0073e9SAndroid Build Coastguard Worker ) 1562*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_non_contiguous_weight(self, device, dtypes): 1563*da0073e9SAndroid Build Coastguard Worker weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device) 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Worker weight_tensor_non_contig = weight_tensor[ 1566*da0073e9SAndroid Build Coastguard Worker :, :3 1567*da0073e9SAndroid Build Coastguard Worker ] # This is non-contiguous strided. 1568*da0073e9SAndroid Build Coastguard Worker weight_tensor_contig = ( 1569*da0073e9SAndroid Build Coastguard Worker weight_tensor_non_contig.clone().contiguous() 1570*da0073e9SAndroid Build Coastguard Worker ) # Contig-strided. 1571*da0073e9SAndroid Build Coastguard Worker 1572*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device) 1573*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2], dtype=dtypes[1], device=device) 1574*da0073e9SAndroid Build Coastguard Worker for mode in ["sum", "mean", "max"]: 1575*da0073e9SAndroid Build Coastguard Worker output_non_contig = F.embedding_bag( 1576*da0073e9SAndroid Build Coastguard Worker input=index, 1577*da0073e9SAndroid Build Coastguard Worker weight=weight_tensor_non_contig, 1578*da0073e9SAndroid Build Coastguard Worker offsets=offsets, 1579*da0073e9SAndroid Build Coastguard Worker mode=mode, 1580*da0073e9SAndroid Build Coastguard Worker ) 1581*da0073e9SAndroid Build Coastguard Worker output_contig = F.embedding_bag( 1582*da0073e9SAndroid Build Coastguard Worker input=index, 1583*da0073e9SAndroid Build Coastguard Worker weight=weight_tensor_contig, 1584*da0073e9SAndroid Build Coastguard Worker offsets=offsets, 1585*da0073e9SAndroid Build Coastguard Worker mode=mode, 1586*da0073e9SAndroid Build Coastguard Worker ) 1587*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_non_contig, output_contig) 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes # currently fails on XLA 1590*da0073e9SAndroid Build Coastguard Worker @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 1591*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_bfloat16(self, device, dtypes): 1592*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 1593*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1594*da0073e9SAndroid Build Coastguard Worker device, 1595*da0073e9SAndroid Build Coastguard Worker "sum", 1596*da0073e9SAndroid Build Coastguard Worker True, 1597*da0073e9SAndroid Build Coastguard Worker wdtype=torch.bfloat16, 1598*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1599*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1600*da0073e9SAndroid Build Coastguard Worker test_backward=True, 1601*da0073e9SAndroid Build Coastguard Worker ) 1602*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1603*da0073e9SAndroid Build Coastguard Worker device, 1604*da0073e9SAndroid Build Coastguard Worker "mean", 1605*da0073e9SAndroid Build Coastguard Worker True, 1606*da0073e9SAndroid Build Coastguard Worker wdtype=torch.bfloat16, 1607*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1608*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1609*da0073e9SAndroid Build Coastguard Worker test_backward=True, 1610*da0073e9SAndroid Build Coastguard Worker ) 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes # currently fails on XLA 1613*da0073e9SAndroid Build Coastguard Worker @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) 1614*da0073e9SAndroid Build Coastguard Worker def test_embedding_bag_half(self, device, dtypes): 1615*da0073e9SAndroid Build Coastguard Worker self._test_EmbeddingBag( 1616*da0073e9SAndroid Build Coastguard Worker device, 1617*da0073e9SAndroid Build Coastguard Worker "sum", 1618*da0073e9SAndroid Build Coastguard Worker True, 1619*da0073e9SAndroid Build Coastguard Worker wdtype=torch.float16, 1620*da0073e9SAndroid Build Coastguard Worker dtype=dtypes[0], 1621*da0073e9SAndroid Build Coastguard Worker odtype=dtypes[1], 1622*da0073e9SAndroid Build Coastguard Worker test_backward=True, 1623*da0073e9SAndroid Build Coastguard Worker ) 1624*da0073e9SAndroid Build Coastguard Worker 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestEmbeddingNNDeviceType, globals()) 1627*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestEmbeddingNN) 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1630*da0073e9SAndroid Build Coastguard Worker run_tests() 1631