# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest


try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


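# The config patch below forces Dynamo to treat the integers captured from
# TorchRec KJTs (lengths/offsets) as unbacked, size-like symbols rather than
# specializing on their concrete values.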
@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module):
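    """Bucketize each feature's per-value weights and gather a learned
    per-bucket parameter for them (torch.bucketize + torch.gather).

    A minimal sketch of the lookup on plain tensors (illustrative values
    only, not part of the test):

        >>> b = torch.tensor([0.0, 0.5, 1.0])
        >>> torch.bucketize(torch.tensor([-1.0, 0.3, 0.7, 2.0]), b)
        tensor([0, 1, 2, 3])

    Values at or below the first boundary land in bucket 0 and values above
    the last in bucket len(boundaries), hence len(boundaries) + 1 weights.
    """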
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # The hashing step doesn't matter for this test; keep the
            # bucketized indices as-is.
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
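        """A pooled EmbeddingBag lookup over a KeyedJaggedTensor should
        compile to a single graph that is reused across batches, for both
        unsync()'d and sync()'d inputs, and the exported graph should
        reproduce the compiled results."""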
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]

        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    # renamed from `f` to avoid shadowing the enclosing
                    # compiled function's name
                    jt = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        jt.values(), jt.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync()'d batches should all hit the same compiled graph, with
        # no recompiles

        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync()'d batches should work too

        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync()'d inputs

        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
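        """torch._dynamo.export with aten_graph=True should trace
        BucketizeMod over an unsync()'d KJT; to_dict() is called first so
        the length symbols get marked size-like (see the comment in f)."""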
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and tell the
            # ShapeEnv that all of these lengths are size-like
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
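        """Indexing a single key out of a sync()'d KJT requires symbolic
        reasoning that Dynamo does not support yet, so this currently
        fails (hence expectedFailure)."""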
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning than
            # Dynamo currently supports
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)

        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)

        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()