# Owner(s): ["oncall: export"] from torch._export.serde.schema_check import ( _Commit, _diff_schema, check, SchemaUpdateError, update_schema, ) from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase class TestSchema(TestCase): def test_schema_compatibility(self): msg = """ Detected an invalidated change to export schema. Please run the following script to update the schema: Example(s): python scripts/export/update_schema.py --prefix """ if IS_FBCODE: msg += """or buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ """ try: commit = update_schema() except SchemaUpdateError as e: self.fail(f"Failed to update schema: {e}\n{msg}") self.assertEqual(commit.checksum_base, commit.checksum_result, msg) def test_schema_diff(self): additions, subtractions = _diff_schema( { "Type0": {"kind": "struct", "fields": {}}, "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, "field2": {"type": ""}, "field3": {"type": "", "default": "[]"}, }, }, }, { "Type2": { "kind": "struct", "fields": { "field1": {"type": "", "default": "0"}, "field2": {"type": "", "default": "[]"}, "field3": {"type": ""}, }, }, "Type1": {"kind": "struct", "fields": {}}, }, ) self.assertEqual( additions, { "Type1": {"kind": "struct", "fields": {}}, "Type2": { "fields": { "field1": {"type": "", "default": "0"}, "field2": {"default": "[]"}, }, }, }, ) self.assertEqual( subtractions, { "Type0": {"kind": "struct", "fields": {}}, "Type2": { "fields": { "field0": {"type": ""}, "field3": {"default": "[]"}, }, }, }, ) def test_schema_check(self): # Adding field without default value dst = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, "field1": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) # Removing field dst = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "struct", "fields": {}, }, "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) # Adding field with default value dst = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, "field1": {"type": "", "default": "[]"}, }, }, "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) # Changing field type dst = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "struct", "fields": { "field0": {"type": "int"}, }, }, "SCHEMA_VERSION": [3, 2], } with self.assertRaises(SchemaUpdateError): _diff_schema(dst, src) # Adding new type. dst = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "Type1": {"kind": "struct", "fields": {}}, "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) # Removing a type. dst = { "Type2": { "kind": "struct", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) # Adding new field in union. dst = { "Type2": { "kind": "union", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "union", "fields": { "field0": {"type": ""}, "field1": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [3, 3]) # Removing a field in union. dst = { "Type2": { "kind": "union", "fields": { "field0": {"type": ""}, }, }, "SCHEMA_VERSION": [3, 2], } src = { "Type2": { "kind": "union", "fields": {}, }, "SCHEMA_VERSION": [3, 2], } additions, subtractions = _diff_schema(dst, src) commit = _Commit( result=src, checksum_result="", path="", additions=additions, subtractions=subtractions, base=dst, checksum_base="", ) next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) if __name__ == "__main__": run_tests()