# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest


try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module):
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # doesn't super matter I guess
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]

        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    f = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        f.values(), f.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync should work

        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync should work too

        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync

        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # ShapeEnv that they're all sizey
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch._dynamo.optimize(counter, nopython=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)

        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)

        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()