1# Owner(s): ["oncall: export"] 2from torch._export.serde.schema_check import ( 3 _Commit, 4 _diff_schema, 5 check, 6 SchemaUpdateError, 7 update_schema, 8) 9from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase 10 11 12class TestSchema(TestCase): 13 def test_schema_compatibility(self): 14 msg = """ 15Detected an invalidated change to export schema. Please run the following script to update the schema: 16Example(s): 17 python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory> 18 """ 19 20 if IS_FBCODE: 21 msg += """or 22 buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ 23 """ 24 try: 25 commit = update_schema() 26 except SchemaUpdateError as e: 27 self.fail(f"Failed to update schema: {e}\n{msg}") 28 29 self.assertEqual(commit.checksum_base, commit.checksum_result, msg) 30 31 def test_schema_diff(self): 32 additions, subtractions = _diff_schema( 33 { 34 "Type0": {"kind": "struct", "fields": {}}, 35 "Type2": { 36 "kind": "struct", 37 "fields": { 38 "field0": {"type": ""}, 39 "field2": {"type": ""}, 40 "field3": {"type": "", "default": "[]"}, 41 }, 42 }, 43 }, 44 { 45 "Type2": { 46 "kind": "struct", 47 "fields": { 48 "field1": {"type": "", "default": "0"}, 49 "field2": {"type": "", "default": "[]"}, 50 "field3": {"type": ""}, 51 }, 52 }, 53 "Type1": {"kind": "struct", "fields": {}}, 54 }, 55 ) 56 57 self.assertEqual( 58 additions, 59 { 60 "Type1": {"kind": "struct", "fields": {}}, 61 "Type2": { 62 "fields": { 63 "field1": {"type": "", "default": "0"}, 64 "field2": {"default": "[]"}, 65 }, 66 }, 67 }, 68 ) 69 self.assertEqual( 70 subtractions, 71 { 72 "Type0": {"kind": "struct", "fields": {}}, 73 "Type2": { 74 "fields": { 75 "field0": {"type": ""}, 76 "field3": {"default": "[]"}, 77 }, 78 }, 79 }, 80 ) 81 82 def test_schema_check(self): 83 # Adding field without default value 84 dst = { 85 "Type2": { 86 "kind": "struct", 87 "fields": { 88 "field0": {"type": ""}, 89 }, 90 }, 91 "SCHEMA_VERSION": [3, 2], 92 } 93 src = { 94 "Type2": { 95 "kind": "struct", 96 "fields": { 97 "field0": {"type": ""}, 98 "field1": {"type": ""}, 99 }, 100 }, 101 "SCHEMA_VERSION": [3, 2], 102 } 103 104 additions, subtractions = _diff_schema(dst, src) 105 106 commit = _Commit( 107 result=src, 108 checksum_result="", 109 path="", 110 additions=additions, 111 subtractions=subtractions, 112 base=dst, 113 checksum_base="", 114 ) 115 next_version, _ = check(commit) 116 self.assertEqual(next_version, [4, 1]) 117 118 # Removing field 119 dst = { 120 "Type2": { 121 "kind": "struct", 122 "fields": { 123 "field0": {"type": ""}, 124 }, 125 }, 126 "SCHEMA_VERSION": [3, 2], 127 } 128 src = { 129 "Type2": { 130 "kind": "struct", 131 "fields": {}, 132 }, 133 "SCHEMA_VERSION": [3, 2], 134 } 135 136 additions, subtractions = _diff_schema(dst, src) 137 138 commit = _Commit( 139 result=src, 140 checksum_result="", 141 path="", 142 additions=additions, 143 subtractions=subtractions, 144 base=dst, 145 checksum_base="", 146 ) 147 next_version, _ = check(commit) 148 self.assertEqual(next_version, [4, 1]) 149 150 # Adding field with default value 151 dst = { 152 "Type2": { 153 "kind": "struct", 154 "fields": { 155 "field0": {"type": ""}, 156 }, 157 }, 158 "SCHEMA_VERSION": [3, 2], 159 } 160 src = { 161 "Type2": { 162 "kind": "struct", 163 "fields": { 164 "field0": {"type": ""}, 165 "field1": {"type": "", "default": "[]"}, 166 }, 167 }, 168 "SCHEMA_VERSION": [3, 2], 169 } 170 171 additions, subtractions = _diff_schema(dst, src) 172 173 commit = _Commit( 174 result=src, 175 checksum_result="", 176 path="", 177 additions=additions, 178 subtractions=subtractions, 179 base=dst, 180 checksum_base="", 181 ) 182 next_version, _ = check(commit) 183 self.assertEqual(next_version, [3, 3]) 184 185 # Changing field type 186 dst = { 187 "Type2": { 188 "kind": "struct", 189 "fields": { 190 "field0": {"type": ""}, 191 }, 192 }, 193 "SCHEMA_VERSION": [3, 2], 194 } 195 src = { 196 "Type2": { 197 "kind": "struct", 198 "fields": { 199 "field0": {"type": "int"}, 200 }, 201 }, 202 "SCHEMA_VERSION": [3, 2], 203 } 204 205 with self.assertRaises(SchemaUpdateError): 206 _diff_schema(dst, src) 207 208 # Adding new type. 209 dst = { 210 "Type2": { 211 "kind": "struct", 212 "fields": { 213 "field0": {"type": ""}, 214 }, 215 }, 216 "SCHEMA_VERSION": [3, 2], 217 } 218 src = { 219 "Type2": { 220 "kind": "struct", 221 "fields": { 222 "field0": {"type": ""}, 223 }, 224 }, 225 "Type1": {"kind": "struct", "fields": {}}, 226 "SCHEMA_VERSION": [3, 2], 227 } 228 229 additions, subtractions = _diff_schema(dst, src) 230 231 commit = _Commit( 232 result=src, 233 checksum_result="", 234 path="", 235 additions=additions, 236 subtractions=subtractions, 237 base=dst, 238 checksum_base="", 239 ) 240 next_version, _ = check(commit) 241 self.assertEqual(next_version, [3, 3]) 242 243 # Removing a type. 244 dst = { 245 "Type2": { 246 "kind": "struct", 247 "fields": { 248 "field0": {"type": ""}, 249 }, 250 }, 251 "SCHEMA_VERSION": [3, 2], 252 } 253 src = { 254 "SCHEMA_VERSION": [3, 2], 255 } 256 257 additions, subtractions = _diff_schema(dst, src) 258 259 commit = _Commit( 260 result=src, 261 checksum_result="", 262 path="", 263 additions=additions, 264 subtractions=subtractions, 265 base=dst, 266 checksum_base="", 267 ) 268 next_version, _ = check(commit) 269 self.assertEqual(next_version, [3, 3]) 270 271 # Adding new field in union. 272 dst = { 273 "Type2": { 274 "kind": "union", 275 "fields": { 276 "field0": {"type": ""}, 277 }, 278 }, 279 "SCHEMA_VERSION": [3, 2], 280 } 281 src = { 282 "Type2": { 283 "kind": "union", 284 "fields": { 285 "field0": {"type": ""}, 286 "field1": {"type": ""}, 287 }, 288 }, 289 "SCHEMA_VERSION": [3, 2], 290 } 291 292 additions, subtractions = _diff_schema(dst, src) 293 294 commit = _Commit( 295 result=src, 296 checksum_result="", 297 path="", 298 additions=additions, 299 subtractions=subtractions, 300 base=dst, 301 checksum_base="", 302 ) 303 next_version, _ = check(commit) 304 self.assertEqual(next_version, [3, 3]) 305 306 # Removing a field in union. 307 dst = { 308 "Type2": { 309 "kind": "union", 310 "fields": { 311 "field0": {"type": ""}, 312 }, 313 }, 314 "SCHEMA_VERSION": [3, 2], 315 } 316 src = { 317 "Type2": { 318 "kind": "union", 319 "fields": {}, 320 }, 321 "SCHEMA_VERSION": [3, 2], 322 } 323 324 additions, subtractions = _diff_schema(dst, src) 325 326 commit = _Commit( 327 result=src, 328 checksum_result="", 329 path="", 330 additions=additions, 331 subtractions=subtractions, 332 base=dst, 333 checksum_base="", 334 ) 335 next_version, _ = check(commit) 336 self.assertEqual(next_version, [4, 1]) 337 338 339if __name__ == "__main__": 340 run_tests() 341