# xref: /aosp_15_r20/external/pytorch/test/export/test_schema.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: export"]
2from torch._export.serde.schema_check import (
3    _Commit,
4    _diff_schema,
5    check,
6    SchemaUpdateError,
7    update_schema,
8)
9from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase
10
11
class TestSchema(TestCase):
    """Tests for export schema diffing, version checking and schema freshness."""

    @staticmethod
    def _type2_schema(fields, kind="struct"):
        """Build a minimal schema containing one ``Type2`` entry at version [3, 2].

        Args:
            fields: mapping of field name -> field spec (``{"type": ...}``,
                optionally with a ``"default"``).
            kind: the ``kind`` recorded for ``Type2`` (``"struct"`` or ``"union"``
                in these tests).

        Returns:
            A fresh dict on every call, so ``dst`` and ``src`` never alias.
        """
        return {
            "Type2": {"kind": kind, "fields": fields},
            "SCHEMA_VERSION": [3, 2],
        }

    @staticmethod
    def _next_version(dst, src):
        """Return the version ``check`` proposes when moving from ``dst`` to ``src``.

        ``dst`` is passed as the commit's base schema and ``src`` as its result
        schema, matching the ``_diff_schema(dst, src)`` argument order. The
        checksum and path fields are empty placeholders.
        """
        additions, subtractions = _diff_schema(dst, src)
        commit = _Commit(
            result=src,
            checksum_result="",
            path="",
            additions=additions,
            subtractions=subtractions,
            base=dst,
            checksum_base="",
        )
        next_version, _ = check(commit)
        return next_version

    def test_schema_compatibility(self):
        """The checked-in schema must match what ``update_schema()`` produces."""
        msg = """
Detected an invalidated change to export schema. Please run the following script to update the schema:
Example(s):
    python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>
        """

        if IS_FBCODE:
            # Internal builds get an additional buck-based update command.
            msg += """or
    buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
            """
        try:
            commit = update_schema()
        except SchemaUpdateError as e:
            # self.fail raises, so `commit` is always bound below.
            self.fail(f"Failed to update schema: {e}\n{msg}")

        # NOTE(review): checksum_base / checksum_result are presumably the
        # committed vs. regenerated schema checksums; equality means no
        # pending schema update.
        self.assertEqual(commit.checksum_base, commit.checksum_result, msg)

    def test_schema_diff(self):
        """``_diff_schema(dst, src)`` reports what ``src`` adds and removes vs ``dst``.

        Expected: ``Type1`` is added and ``Type0`` removed. Within ``Type2``,
        ``field1`` is added and ``field0`` removed; ``field2`` gaining a default
        shows up as an addition of only ``{"default": ...}``, and ``field3``
        losing its default as a subtraction of only ``{"default": ...}``.
        """
        additions, subtractions = _diff_schema(
            {
                "Type0": {"kind": "struct", "fields": {}},
                "Type2": {
                    "kind": "struct",
                    "fields": {
                        "field0": {"type": ""},
                        "field2": {"type": ""},
                        "field3": {"type": "", "default": "[]"},
                    },
                },
            },
            {
                "Type2": {
                    "kind": "struct",
                    "fields": {
                        "field1": {"type": "", "default": "0"},
                        "field2": {"type": "", "default": "[]"},
                        "field3": {"type": ""},
                    },
                },
                "Type1": {"kind": "struct", "fields": {}},
            },
        )

        self.assertEqual(
            additions,
            {
                "Type1": {"kind": "struct", "fields": {}},
                "Type2": {
                    "fields": {
                        "field1": {"type": "", "default": "0"},
                        "field2": {"default": "[]"},
                    },
                },
            },
        )
        self.assertEqual(
            subtractions,
            {
                "Type0": {"kind": "struct", "fields": {}},
                "Type2": {
                    "fields": {
                        "field0": {"type": ""},
                        "field3": {"default": "[]"},
                    },
                },
            },
        )

    def test_schema_check(self):
        """``check`` proposes the expected next version for each kind of change.

        Every case starts from SCHEMA_VERSION [3, 2] and expects either a major
        bump ([4, 1]) or a minor bump ([3, 3]). Each case runs in its own
        subTest so one failure does not hide the others.
        """
        with self.subTest("adding struct field without default"):
            dst = self._type2_schema({"field0": {"type": ""}})
            src = self._type2_schema(
                {"field0": {"type": ""}, "field1": {"type": ""}}
            )
            self.assertEqual(self._next_version(dst, src), [4, 1])

        with self.subTest("removing struct field"):
            dst = self._type2_schema({"field0": {"type": ""}})
            src = self._type2_schema({})
            self.assertEqual(self._next_version(dst, src), [4, 1])

        with self.subTest("adding struct field with default"):
            dst = self._type2_schema({"field0": {"type": ""}})
            src = self._type2_schema(
                {"field0": {"type": ""}, "field1": {"type": "", "default": "[]"}}
            )
            self.assertEqual(self._next_version(dst, src), [3, 3])

        with self.subTest("changing field type"):
            dst = self._type2_schema({"field0": {"type": ""}})
            src = self._type2_schema({"field0": {"type": "int"}})
            # A type change is rejected by the diff itself, before `check` runs.
            with self.assertRaises(SchemaUpdateError):
                _diff_schema(dst, src)

        with self.subTest("adding a type"):
            dst = self._type2_schema({"field0": {"type": ""}})
            src = {
                "Type2": {"kind": "struct", "fields": {"field0": {"type": ""}}},
                "Type1": {"kind": "struct", "fields": {}},
                "SCHEMA_VERSION": [3, 2],
            }
            self.assertEqual(self._next_version(dst, src), [3, 3])

        with self.subTest("removing a type"):
            dst = self._type2_schema({"field0": {"type": ""}})
            src = {"SCHEMA_VERSION": [3, 2]}
            self.assertEqual(self._next_version(dst, src), [3, 3])

        with self.subTest("adding union field"):
            dst = self._type2_schema({"field0": {"type": ""}}, kind="union")
            src = self._type2_schema(
                {"field0": {"type": ""}, "field1": {"type": ""}}, kind="union"
            )
            self.assertEqual(self._next_version(dst, src), [3, 3])

        with self.subTest("removing union field"):
            dst = self._type2_schema({"field0": {"type": ""}}, kind="union")
            src = self._type2_schema({}, kind="union")
            self.assertEqual(self._next_version(dst, src), [4, 1])
337
338
# Entry point when the file is executed directly: hand off to the test runner
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()