1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-ignore-all-errors 8 9# Copied over from caffe2/torch/_export/serde/serialize.py until dialects 10# are supported in torch export serializer. 11 12import base64 13import copy 14import copyreg 15import dataclasses 16import heapq 17import inspect 18import io 19import json 20import logging 21import math 22import operator 23import re 24import typing 25 26from contextlib import contextmanager 27from dataclasses import dataclass, field 28from enum import Enum 29from typing import ( 30 Any, 31 Callable, 32 cast, 33 Dict, 34 final, 35 Iterator, 36 List, 37 Optional, 38 Set, 39 Tuple, 40 Union, 41) 42 43import sympy 44 45import torch 46import torch.export.exported_program 47import torch.export.exported_program as ep 48from torch._export.serde.schema import SchemaVersion 49from torch._export.verifier import load_verifier 50from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 51from torch.fx.experimental import symbolic_shapes 52from torch.utils import _pytree as pytree 53from torch.utils._pytree import treespec_dumps, treespec_loads 54from torch.utils._sympy.numbers import int_oo 55from torch.utils._sympy.value_ranges import ValueRanges 56 57# pyre-ignore 58 59from .schema import ( # type: ignore[attr-defined] 60 Argument, 61 BufferMutationSpec, 62 ConstantInputSpec, 63 ConstantValue, 64 CustomObjArgument, 65 Device, 66 ExportedProgram, 67 GradientToParameterSpec, 68 GradientToUserInputSpec, 69 Graph, 70 GraphArgument, 71 GraphModule, 72 GraphSignature, 73 InputSpec, 74 InputToBufferSpec, 75 InputToCustomObjSpec, 76 InputTokenSpec, 77 InputToParameterSpec, 78 InputToTensorConstantSpec, 79 Layout, 80 LossOutputSpec, 81 MemoryFormat, 82 ModuleCallEntry, 83 ModuleCallSignature, 84 NamedArgument, 85 Node, 86 OptionalTensorArgument, 87 OutputSpec, 88 OutputTokenSpec, 89 RangeConstraint, 90 ScalarType, 91 SCHEMA_VERSION, 92 SymBool, 93 SymBoolArgument, 94 SymExpr, 95 SymExprHint, 96 SymInt, 97 SymIntArgument, 98 TensorArgument, 99 TensorMeta, 100 TokenArgument, 101 TREESPEC_VERSION, 102 UserInputMutationSpec, 103 UserInputSpec, 104 UserOutputSpec, 105) 106from .union import _Union 107 108 109__all__ = [ 110 "serialize", 111 "GraphModuleSerializer", 112 "ExportedProgramSerializer", 113 "GraphModuleDeserializer", 114 "ExportedProgramDeserializer", 115] 116 117from .upgrade import GraphModuleOpUpgrader 118 119log = logging.getLogger(__name__) 120 121 122class SerializeError(RuntimeError): 123 pass 124 125 126def _reverse_map(d: Dict[Any, Enum]): 127 return {v.value: k for k, v in d.items()} 128 129 130MetaType = Union[ 131 FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument 132] 133 134 135ST_DELIMITER = ";" 136 137_TORCH_TO_SERIALIZE_DTYPE = { 138 torch.uint8: ScalarType.BYTE, 139 torch.int8: ScalarType.CHAR, 140 torch.int16: ScalarType.SHORT, 141 torch.int32: ScalarType.INT, 142 torch.int64: ScalarType.LONG, 143 torch.float16: ScalarType.HALF, 144 torch.float32: ScalarType.FLOAT, 145 torch.float64: ScalarType.DOUBLE, 146 torch.complex32: ScalarType.COMPLEXHALF, 147 torch.complex64: ScalarType.COMPLEXFLOAT, 148 torch.complex128: ScalarType.COMPLEXDOUBLE, 149 torch.bool: ScalarType.BOOL, 150 torch.bfloat16: ScalarType.BFLOAT16, 151 torch.uint16: ScalarType.UINT16 152} 153 154 155_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] 156 157 158_TORCH_TO_SERIALIZE_LAYOUT = { 159 torch.sparse_coo: Layout.SparseCoo, 160 torch.sparse_csr: Layout.SparseCsr, 161 torch.sparse_csc: Layout.SparseCsc, 162 torch.sparse_bsr: Layout.SparseBsr, 163 torch.sparse_bsc: Layout.SparseBsc, 164 torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] 165 torch.strided: Layout.Strided, 166} 167 168 169_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] 170 171 172_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { 173 torch.contiguous_format: MemoryFormat.ContiguousFormat, 174 torch.channels_last: MemoryFormat.ChannelsLast, 175 torch.channels_last_3d: MemoryFormat.ChannelsLast3d, 176 torch.preserve_format: MemoryFormat.PreserveFormat, 177} 178 179 180_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] 181 182 183_SYM_INT_OPS = { 184 operator.mul, 185 operator.add, 186 operator.sub, 187 operator.floordiv, 188 operator.mod, 189 torch.sym_int, 190 torch.sym_float, 191 torch.sym_ite, 192 torch.sym_max, 193 torch.sym_min, 194 torch.sym_sqrt, 195} 196 197 198_SYM_BOOL_OPS = { 199 operator.eq, 200 operator.ne, 201 operator.le, 202 operator.ge, 203 operator.lt, 204 operator.gt, 205 torch.sym_not, 206} 207 208 209@dataclass 210class SerializedArtifact: 211 exported_program: bytes 212 state_dict: bytes 213 constants: bytes 214 example_inputs: bytes 215 216 217@dataclass 218class _SerializedProgram: 219 exported_program: ExportedProgram 220 state_dict: bytes 221 constants: bytes 222 example_inputs: bytes 223 224 225def deserialize_device(d: Device) -> torch.device: 226 if d.index is None: 227 return torch.device(type=d.type) # type: ignore[call-overload] 228 return torch.device(type=d.type, index=d.index) 229 230 231def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: 232 if isinstance(s, (torch.SymInt, int)): 233 if symbolic_shapes.is_concrete_int(s): 234 return SymInt.create(as_int=int(s)) 235 else: 236 assert isinstance(s, torch.SymInt) 237 if s.node.hint is None: 238 return SymInt.create(as_expr=SymExpr(str(s))) 239 else: 240 return SymInt.create( 241 as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint)) 242 ) 243 else: 244 raise SerializeError( 245 f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" 246 ) 247 248 249def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: 250 if isinstance(s, (torch.SymBool, bool)): 251 if symbolic_shapes.is_concrete_bool(s): 252 return SymBool.create(as_bool=bool(s)) 253 else: 254 return SymBool.create(as_expr=SymExpr(expr_str=str(s))) 255 else: 256 raise SerializeError( 257 f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" 258 ) 259 260 261def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: 262 """ 263 Extract a TensorMeta describing `t`. 264 """ 265 return TensorMeta( 266 dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], 267 sizes=[serialize_sym_int(s) for s in t.shape], 268 requires_grad=t.requires_grad, 269 device=Device(type=t.device.type, index=t.device.index), 270 strides=[serialize_sym_int(s) for s in t.stride()], 271 storage_offset=serialize_sym_int(0), # TODO needs to be fixed. 272 layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], 273 ) 274 275 276_CURRENT_DESERIALIZER: List["GraphModuleDeserializer"] = [] 277 278 279def _reduce_fake_tensor(fake_tensor: FakeTensor): 280 is_parameter = isinstance(fake_tensor, torch.nn.Parameter) 281 tensor_meta = serialize_tensor_meta(fake_tensor) 282 tensor_meta_bytes = json.dumps( 283 _dataclass_to_dict(tensor_meta), cls=EnumEncoder 284 ).encode("utf-8") 285 return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) 286 287 288def _reconstruct_fake_tensor( 289 serialized_tensor_meta: bytes, is_parameter: bool 290) -> FakeTensor: 291 # Deserialize the bytes into a TensorMeta 292 json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) 293 tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) 294 # Find the current fake mode 295 assert len(_CURRENT_DESERIALIZER) != 0, "Need access to current deserializer state" 296 fake_tensor = _CURRENT_DESERIALIZER[-1].deserialize_tensor_meta(tensor_meta) 297 if is_parameter: 298 fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] 299 return fake_tensor 300 301 302def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes: 303 assert ( 304 FakeTensor not in copyreg.dispatch_table 305 ), "Refusing to stomp on existing FakeTensor reducer" 306 try: 307 copyreg.pickle(FakeTensor, _reduce_fake_tensor) 308 buffer = io.BytesIO() 309 # This is a workaround for backend's tensor deserialization problem: 310 # unpickleTensor() always create a tensor on the device where it was originally saved 311 # This behavior is bad for multi-gpu training, as we wish to directly load the tensor 312 # on the designated device. 313 # For now, we simply move the tensor to cpu before saving. 314 # TODO: this should be fixed by deserialization instead. 315 torch.save(artifact, buffer) 316 return buffer.getvalue() 317 finally: 318 del copyreg.dispatch_table[FakeTensor] 319 320 321def deserialize_torch_artifact( 322 serialized: Union[Dict[str, Any], Tuple[Any, ...], bytes] 323): 324 if isinstance(serialized, (dict, tuple)): 325 return serialized 326 if len(serialized) == 0: 327 return {} 328 buffer = io.BytesIO(serialized) 329 buffer.seek(0) 330 artifact = torch.load(buffer) 331 assert isinstance(artifact, (tuple, dict)) 332 return artifact 333 334 335def _sympy_int_to_int(val: sympy.Expr, adjust: str): 336 # Convert simple sympy Integers into concrete int 337 if val in (sympy.oo, int_oo): 338 return math.inf 339 if val in (-sympy.oo, -int_oo): 340 return -math.inf 341 if isinstance(val, sympy.Integer): 342 return int(val) 343 344 # TODO: Remove this adjustment when Ed gets rid of fractional ranges 345 log.warning( 346 "Export constraints cannot be non-integer expressions. Found " 347 "type %s, and value %s. We will attempt to %s " 348 "this value.", 349 type(val), 350 val, 351 adjust, 352 ) 353 354 if adjust == "floor": 355 return math.floor(val) 356 elif adjust == "ceil": 357 return math.ceil(val) 358 else: 359 raise RuntimeError(f"Got invalid adjustment {adjust}") 360 361 362def _int_to_sympy_int(val) -> sympy.Expr: 363 # Convert concrete int into simple sympy Integers 364 if val == math.inf: 365 return int_oo 366 if val == -math.inf: 367 return -int_oo 368 return sympy.Integer(val) 369 370 371def serialize_range_constraints( 372 range_constraints: Dict[sympy.Symbol, ValueRanges] 373) -> Dict[str, RangeConstraint]: 374 return { 375 str(k): RangeConstraint( 376 _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] 377 _sympy_int_to_int(v.upper, "floor"), # type: ignore[arg-type] 378 ) 379 for k, v in range_constraints.items() 380 } 381 382 383def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool: 384 returns = target._schema.returns 385 return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType) 386 387 388def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool: 389 returns = target._schema.returns 390 if len(returns) != 1: 391 return False 392 return_type = returns[0].real_type 393 return isinstance(return_type, torch.ListType) and isinstance( 394 return_type.getElementType(), torch.TensorType 395 ) 396 397 398def _output_node_at_index(node, index): 399 for user in node.users: 400 assert user.target is operator.getitem, f"{user} is not a getitem node" 401 if index == user.args[1]: 402 return user 403 return None 404 405 406@dataclass 407class GraphState: 408 inputs: List[Argument] = field(default_factory=list) 409 outputs: List[Argument] = field(default_factory=list) 410 nodes: List[Node] = field(default_factory=list) 411 tensor_values: Dict[str, TensorMeta] = field(default_factory=dict) 412 sym_int_values: Dict[str, SymInt] = field(default_factory=dict) 413 sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) 414 is_single_tensor_return: bool = False 415 custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) 416 417 418class Final(type): 419 def __new__(metacls, name, bases, classdict): 420 for b in bases: 421 if isinstance(b, Final): 422 raise TypeError(f"type '{b.__name__}' is not an acceptable base type") 423 return type.__new__(metacls, name, bases, dict(classdict)) 424 425 426class GraphModuleSerializer: 427 def __init__( 428 self, 429 graph_signature: ep.ExportGraphSignature, 430 module_call_graph: List[ep.ModuleCallEntry], 431 ): 432 self.graph_state = GraphState() 433 self.graph_signature = graph_signature 434 self.module_call_graph = module_call_graph 435 self.custom_objs: Dict[str, torch._C.ScriptObject] = {} 436 437 @contextmanager 438 def save_graph_state(self): 439 saved = self.graph_state 440 self.graph_state = GraphState() 441 try: 442 yield 443 finally: 444 self.graph_state = saved 445 446 def handle_placeholder(self, node: torch.fx.Node): 447 assert node.op == "placeholder" 448 if isinstance(node.meta["val"], torch.Tensor): 449 graph_input = Argument.create(as_tensor=TensorArgument(name=node.name)) 450 self.graph_state.tensor_values[node.name] = serialize_tensor_meta( 451 node.meta["val"] 452 ) 453 elif isinstance(node.meta["val"], torch.SymInt): 454 graph_input = Argument.create( 455 as_sym_int=SymIntArgument.create(as_name=node.name) 456 ) 457 self.graph_state.sym_int_values[node.name] = serialize_sym_int( 458 node.meta["val"] 459 ) 460 elif isinstance(node.meta["val"], (int, bool, str, float, type(None))): 461 graph_input = self.serialize_input(node.meta["val"]) 462 elif isinstance(node.meta["val"], ep.CustomObjArgument): 463 class_fqn = node.meta["val"].class_fqn 464 graph_input = Argument.create( 465 as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) 466 ) 467 self.graph_state.custom_obj_values[node.name] = ( 468 self.serialize_script_obj_meta(node.meta["val"]) 469 ) 470 else: 471 raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") 472 self.graph_state.inputs.append(graph_input) 473 474 def handle_output(self, node: torch.fx.Node): 475 assert node.op == "output" 476 assert len(node.args) == 1, "FX.Node's args should have one arg" 477 node_args = node.args[0] 478 if isinstance(node_args, torch.fx.Node): 479 # For singleton tensor returns 480 self.graph_state.is_single_tensor_return = True 481 self.graph_state.outputs = [self.serialize_input(node_args)] 482 else: 483 assert isinstance(node_args, (tuple, list)) 484 self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] 485 486 def serialize_operator(self, target) -> str: 487 if isinstance(target, str): 488 return target 489 elif target.__module__.startswith("torch._ops"): 490 # TODO(zhxchen17) Maybe provide a function name helper in FX. 491 # From torch.fx.node._get_qualified_name 492 module = target.__module__.replace("torch._ops", "torch.ops") 493 return f"{module}.{target.__name__}" 494 else: # TODO(zhxchen17) Don't catch all here. 495 return f"{target.__module__}.{target.__name__}" 496 497 def handle_call_function(self, node: torch.fx.Node): 498 assert node.op == "call_function" 499 500 # getitem has been handled in the producer node, skip it here 501 if node.target is operator.getitem: 502 return 503 504 if node.target in _SYM_INT_OPS: 505 assert len(node.kwargs) == 0 506 meta_val = node.meta["val"] 507 ex_node = Node( 508 target=self.serialize_operator(node.target), 509 inputs=self.serialize_sym_op_inputs(node.target, node.args), 510 outputs=[ 511 Argument.create( 512 as_sym_int=self.serialize_sym_int_output(node.name, meta_val) 513 ) 514 ], 515 metadata=self.serialize_metadata(node), 516 ) 517 elif node.target in _SYM_BOOL_OPS: 518 assert len(node.kwargs) == 0 519 meta_val = node.meta["val"] 520 ex_node = Node( 521 target=self.serialize_operator(node.target), 522 inputs=self.serialize_sym_op_inputs(node.target, node.args), 523 outputs=[ 524 Argument.create( 525 as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val) 526 ) 527 ], 528 metadata=self.serialize_metadata(node), 529 ) 530 elif isinstance(node.target, torch._ops.OpOverload): 531 ex_node = Node( 532 target=self.serialize_operator(node.target), 533 inputs=self.serialize_inputs(node.target, node.args, node.kwargs), 534 outputs=self.serialize_outputs(node), 535 # TODO: create a new tensor_values here, meta might have faketensor info 536 metadata=self.serialize_metadata(node), 537 ) 538 elif isinstance(node.target, torch._ops.HigherOrderOperator): 539 ex_node = Node( 540 target=self.serialize_operator(node.target), 541 inputs=self.serialize_hoo_inputs(node.args, node.kwargs), 542 outputs=self.serialize_hoo_outputs(node), 543 metadata=self.serialize_metadata(node), 544 ) 545 else: 546 raise SerializeError(f"Serializing {node.target} is not supported") 547 548 self.graph_state.nodes.append(ex_node) 549 550 def handle_get_attr(self, node): 551 pass 552 553 def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: 554 ret = {} 555 if stack_trace := node.meta.get("stack_trace"): 556 ret["stack_trace"] = stack_trace 557 558 if nn_module_stack := node.meta.get("nn_module_stack"): 559 560 def export_nn_module_stack(val): 561 assert isinstance(val, tuple) and len(val) == 2 562 path, ty = val 563 564 assert isinstance(path, str) 565 566 # node.meta["nn_module_stack"] could have two forms: 567 # 1. (path: str, module_type: 'type'), e.g. 568 # ('', <class 'sigmoid.inference.MySimpleModel'>) 569 # 2. (path: str, module_type: str), e.g. 570 # ('', 'sigmoid.inference.MySimpleModel') 571 # ExportedProgram directly produced by torch.export() has form 1 572 # ExportedProgram deserialized from disk has form 2 573 # TODO: This is not ideal, we should fix this. 574 if isinstance(ty, str): 575 normalized_ty = ty 576 else: 577 normalized_ty = ty.__module__ + "." + ty.__qualname__ 578 579 return path + "," + normalized_ty 580 581 # Serialize to "key,orig_path,type_str" 582 nn_module_list = [ 583 f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() 584 ] 585 ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) 586 587 if source_fn_st := node.meta.get("source_fn_stack"): 588 source_fn_list = [ 589 f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" 590 for source_fn in source_fn_st 591 ] 592 ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) 593 594 if torch_fn := node.meta.get("torch_fn"): 595 ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn)) 596 597 return ret 598 599 def serialize_script_obj_meta( 600 self, script_obj_meta: ep.CustomObjArgument 601 ) -> CustomObjArgument: 602 return CustomObjArgument( 603 name=script_obj_meta.name, 604 class_fqn=script_obj_meta.class_fqn, 605 ) 606 607 def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: 608 serialized_args = [] 609 args_names = inspect.signature(op).parameters.keys() 610 for args_name, arg in zip(args_names, args): 611 serialized_args.append( 612 NamedArgument(name=args_name, arg=self.serialize_input(arg)) 613 ) 614 return serialized_args 615 616 def serialize_inputs( 617 self, target: torch._ops.OpOverload, args, kwargs=None 618 ) -> List[NamedArgument]: 619 assert isinstance(target, torch._ops.OpOverload) 620 kwargs = kwargs or {} 621 serialized_args = [] 622 for i, schema_arg in enumerate(target._schema.arguments): 623 if schema_arg.name in kwargs: 624 serialized_args.append( 625 NamedArgument( 626 name=schema_arg.name, 627 arg=self.serialize_input( 628 kwargs[schema_arg.name], schema_arg.type 629 ), 630 ) 631 ) 632 elif not schema_arg.kwarg_only and i < len(args): 633 serialized_args.append( 634 NamedArgument( 635 name=schema_arg.name, 636 arg=self.serialize_input(args[i], schema_arg.type), 637 ) 638 ) 639 else: 640 # We intentionally don't serialize the missing arguments 641 # with default values 642 pass 643 644 return serialized_args 645 646 def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: 647 """ 648 For serializing HOO inputs since HOOs do not have a schema. 649 """ 650 inputs = [ 651 NamedArgument( 652 name="", 653 arg=self.serialize_input(a), 654 ) 655 for a in args 656 ] 657 inputs.extend( 658 [ 659 NamedArgument(name=name, arg=self.serialize_input(a)) 660 for name, a in kwargs.items() 661 ] 662 ) 663 return inputs 664 665 def is_sym_int_arg(self, arg) -> bool: 666 return isinstance(arg, int) or ( 667 isinstance(arg, torch.fx.Node) 668 and arg.name in self.graph_state.sym_int_values 669 ) 670 671 def is_sym_bool_arg(self, arg) -> bool: 672 return isinstance(arg, bool) or ( 673 isinstance(arg, torch.fx.Node) 674 and arg.name in self.graph_state.sym_bool_values 675 ) 676 677 def serialize_input( 678 self, arg, arg_type: Optional[torch._C.Argument] = None 679 ) -> Argument: 680 import torch._inductor.ir as inductor_ir 681 682 inductor_tensor_buffers = ( 683 inductor_ir.Buffer, 684 inductor_ir.ReinterpretView, 685 ) 686 687 if isinstance(arg, torch.fx.Node): 688 if arg.op == "get_attr": 689 assert isinstance(arg.target, str) 690 attr = getattr(arg.graph.owning_module, arg.target) 691 692 if isinstance(attr, torch.Tensor): 693 raise SerializeError( 694 "getattr nodes containing tensors should not appear in the graph" 695 ) 696 elif isinstance(attr, torch.fx.GraphModule): 697 with self.save_graph_state(): 698 graph = self.serialize_graph(attr) 699 return Argument.create( 700 as_graph=GraphArgument(name=arg.target, graph=graph) 701 ) 702 else: 703 raise SerializeError( 704 f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" 705 ) 706 elif self.is_sym_int_arg(arg): 707 return Argument.create( 708 as_sym_int=SymIntArgument.create(as_name=arg.name) 709 ) 710 elif self.is_sym_bool_arg(arg): 711 return Argument.create( 712 as_sym_bool=SymBoolArgument.create(as_name=arg.name) 713 ) 714 else: 715 if isinstance(arg.meta["val"], ep.CustomObjArgument): 716 return Argument.create( 717 as_custom_obj=CustomObjArgument( 718 name=arg.name, class_fqn=arg.meta["val"].class_fqn 719 ) 720 ) 721 return Argument.create(as_tensor=TensorArgument(name=arg.name)) 722 elif isinstance(arg, inductor_tensor_buffers): 723 # Other branches are for arguments in fx node. 724 # This is a special branch for handling buffers (representing tensor arguments) 725 # for inductor's ExternalFallbackNode 726 # export_extern_kernel_node() is using this function to serialize arguments 727 arg_name = arg.get_name() 728 assert arg_name is not None, "Buffer must have valid name" 729 return Argument.create(as_tensor=TensorArgument(name=arg_name)) 730 elif isinstance(arg, torch.SymInt): 731 # This is a special branch for handling SymInt args in inductor's 732 # ExternalFallbackNode. 733 # For regular FX graph, SymInt arg should be a fx.Node with 734 # self.is_sym_int_arg(arg) being true 735 return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) 736 elif isinstance(arg, bool): 737 return Argument.create(as_bool=arg) 738 elif isinstance(arg, str): 739 return Argument.create(as_string=arg) 740 elif isinstance(arg, int): 741 return Argument.create(as_int=arg) 742 elif isinstance(arg, float): 743 return Argument.create(as_float=arg) 744 elif arg is None: 745 return Argument.create(as_none=()) 746 elif isinstance(arg, (list, tuple)): 747 if len(arg) == 0: 748 if arg_type is not None: 749 if isinstance(arg_type, torch.OptionalType): 750 arg_type = arg_type.getElementType() # type: ignore[assignment] 751 assert isinstance(arg_type, torch.ListType) 752 elem_type = arg_type.getElementType() 753 if isinstance(elem_type, torch.OptionalType): 754 elem_type = elem_type.getElementType() 755 756 if isinstance(elem_type, torch.BoolType): 757 return Argument.create(as_bools=[]) 758 elif isinstance(elem_type, torch.IntType): 759 return Argument.create(as_ints=[]) 760 elif isinstance(elem_type, torch.FloatType): 761 return Argument.create(as_floats=[]) 762 elif isinstance(elem_type, torch.StringType): 763 return Argument.create(as_strings=[]) 764 elif isinstance(elem_type, torch.TensorType): 765 return Argument.create(as_tensors=[]) 766 else: 767 # I believe empty symint lists default to ints, but 768 # please file an issue if this is not the case 769 raise SerializeError(f"Empty list with type {elem_type} nyi.") 770 else: 771 # We could serialize this by default to a tensor list. This 772 # is needed in the HOO case 773 log.warning( 774 "Unsure how to serialize the given empty list, " 775 "as we don't know what is the type of this argument. " 776 "Serializing it as a tensor list by default." 777 ) 778 return Argument.create(as_tensors=[]) 779 780 # Must check bool first, as bool is also treated as int 781 if all(isinstance(a, bool) for a in arg): 782 return Argument.create(as_bools=list(arg)) 783 elif all(isinstance(a, int) for a in arg): 784 return Argument.create(as_ints=list(arg)) 785 elif all(isinstance(a, float) for a in arg): 786 return Argument.create(as_floats=list(arg)) 787 elif all(isinstance(a, str) for a in arg): 788 return Argument.create(as_strings=list(arg)) 789 elif all(isinstance(a, torch.SymInt) for a in arg): 790 # This is a special branch for handling SymInt args in inductor's 791 # ExternalFallbackNode. 792 # For regular FX graph, SymInt arg should be a fx.Node with 793 # self.is_sym_int_arg(arg) being true 794 return Argument.create( 795 as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg] 796 ) 797 elif all(self.is_sym_int_arg(a) for a in arg): 798 # list of sym_ints 799 values = [] 800 for a in arg: 801 if isinstance(a, torch.fx.Node): 802 values.append(SymIntArgument.create(as_name=a.name)) 803 elif isinstance(a, int): 804 values.append(SymIntArgument.create(as_int=a)) 805 return Argument.create(as_sym_ints=values) 806 elif all(self.is_sym_bool_arg(a) for a in arg): 807 # list of sym_bools 808 values = [] 809 for a in arg: 810 if isinstance(a, torch.fx.Node): 811 values.append(SymBoolArgument.create(as_name=a.name)) 812 elif isinstance(a, bool): 813 values.append(SymBoolArgument.create(as_bool=a)) 814 return Argument.create(as_sym_bools=values) 815 elif all(isinstance(a, torch.fx.Node) for a in arg): 816 # list of tensors 817 arguments = [] 818 for a in arg: 819 if a.op == "get_attr": 820 raise SerializeError( 821 "getattr nodes containing tensors should not appear in the graph" 822 ) 823 arguments.append(TensorArgument(name=a.name)) 824 return Argument.create(as_tensors=arguments) 825 elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): 826 # list of optional tensors 827 def serialize_optional_tensor_args(a): 828 if a is None: 829 return OptionalTensorArgument.create(as_none=()) 830 elif isinstance(a, torch.fx.Node): 831 return OptionalTensorArgument.create( 832 as_tensor=TensorArgument(name=a.name) 833 ) 834 else: 835 raise SerializeError(f"Unsupported list/tuple argument: {a}") 836 837 return Argument.create( 838 as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) 839 ) 840 elif all(isinstance(a, inductor_tensor_buffers) for a in arg): 841 # list of inductor buffers 842 return Argument.create( 843 as_tensors=[TensorArgument(name=a.get_name()) for a in arg], 844 ) 845 elif all( 846 isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg 847 ): 848 # list of inductor buffers as optional tensors 849 def serialize_optional_tensor_args(a): 850 if a is None: 851 return OptionalTensorArgument.create(as_none=()) 852 elif isinstance(a, inductor_tensor_buffers): 853 return OptionalTensorArgument.create( 854 as_tensor=TensorArgument(name=a.get_name()) 855 ) 856 else: 857 raise SerializeError(f"Unsupported list/tuple argument: {a}") 858 859 return Argument.create( 860 as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) 861 ) 862 else: 863 raise SerializeError( 864 f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" 865 ) 866 elif isinstance(arg, torch.dtype): 867 return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) 868 elif isinstance(arg, torch.device): 869 return Argument.create(as_device=Device(type=arg.type, index=arg.index)) 870 elif isinstance(arg, torch.memory_format): 871 return Argument.create( 872 as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg] 873 ) 874 elif isinstance(arg, torch.layout): 875 return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) 876 elif isinstance(arg, torch._C.ScriptObject): 877 if not ( 878 arg._has_method("__getstate__") # type: ignore[attr-defined] 879 and arg._has_method("__setstate__") # type: ignore[attr-defined] 880 ): 881 raise SerializeError( 882 f"Unable to serialize custom class {arg}. Please define " 883 "serialization methods via def_pickle()." 884 ) 885 # Custom objects through torchind are serializable with pickle, 886 # through implementing the .def_pickle function. This should result 887 # in the object containing a __getstate__ and __setstate__ 888 # serialize/deserialize function. 889 custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" 890 self.custom_objs[custom_obj_name] = arg 891 class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] 892 return Argument.create( 893 as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) 894 ) 895 elif isinstance(arg, torch._ops.OpOverload): 896 return Argument.create(as_operator=self.serialize_operator(arg)) 897 else: 898 raise SerializeError(f"Unsupported argument type: {type(arg)}") 899 900 def serialize_tensor_output(self, name, meta_val) -> TensorArgument: 901 assert name not in self.graph_state.tensor_values 902 self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) 903 return TensorArgument(name=name) 904 905 def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: 906 assert name not in self.graph_state.sym_int_values 907 self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) 908 return SymIntArgument.create(as_name=name) 909 910 def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: 911 assert name not in self.graph_state.sym_bool_values 912 self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) 913 return SymBoolArgument.create(as_name=name) 914 915 def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: 916 if spec.kind == ep.InputKind.USER_INPUT: 917 if isinstance(spec.arg, ep.ConstantArgument): 918 if isinstance(spec.arg.value, int): 919 constant_spec = ConstantValue.create(as_int=spec.arg.value) 920 elif isinstance(spec.arg.value, bool): 921 constant_spec = ConstantValue.create(as_bool=spec.arg.value) 922 elif isinstance(spec.arg.value, str): 923 constant_spec = ConstantValue.create(as_string=spec.arg.value) 924 elif isinstance(spec.arg.value, float): 925 constant_spec = ConstantValue.create(as_float=spec.arg.value) 926 elif spec.arg.value is None: 927 constant_spec = ConstantValue.create(as_none=()) 928 else: 929 raise SerializeError( 930 f"Unhandled constant input {spec.arg.value} to serialize" 931 ) 932 return InputSpec.create( 933 constant_input=ConstantInputSpec( 934 name=spec.arg.name, value=constant_spec 935 ) 936 ) 937 else: 938 return InputSpec.create( 939 user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg)) 940 ) 941 elif spec.kind == ep.InputKind.PARAMETER: 942 assert spec.target is not None 943 assert isinstance(spec.arg, ep.TensorArgument) 944 return InputSpec.create( 945 parameter=InputToParameterSpec( 946 arg=TensorArgument(name=spec.arg.name), 947 parameter_name=spec.target, 948 ) 949 ) 950 elif spec.kind == ep.InputKind.BUFFER: 951 assert spec.target is not None 952 assert isinstance(spec.arg, ep.TensorArgument) 953 assert spec.persistent is not None 954 return InputSpec.create( 955 buffer=InputToBufferSpec( 956 arg=TensorArgument(name=spec.arg.name), 957 buffer_name=spec.target, 958 persistent=spec.persistent, 959 ) 960 ) 961 elif spec.kind == ep.InputKind.CONSTANT_TENSOR: 962 assert spec.target is not None 963 assert isinstance(spec.arg, ep.TensorArgument) 964 return InputSpec.create( 965 tensor_constant=InputToTensorConstantSpec( 966 arg=TensorArgument(name=spec.arg.name), 967 tensor_constant_name=spec.target, 968 ) 969 ) 970 elif spec.kind == ep.InputKind.CUSTOM_OBJ: 971 assert spec.target is not None 972 assert isinstance(spec.arg, ep.CustomObjArgument) 973 return InputSpec.create( 974 custom_obj=InputToCustomObjSpec( 975 arg=CustomObjArgument( 976 name=spec.arg.name, class_fqn=spec.arg.class_fqn 977 ), 978 custom_obj_name=spec.target, 979 ) 980 ) 981 elif spec.kind == ep.InputKind.TOKEN: 982 assert isinstance(spec.arg, ep.TokenArgument) 983 return InputSpec.create( 984 token=InputTokenSpec( 985 arg=TokenArgument(name=spec.arg.name), 986 ) 987 ) 988 else: 989 raise AssertionError(f"Unknown argument kind: {spec}") 990 991 def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: 992 if spec.kind == ep.OutputKind.USER_OUTPUT: 993 return OutputSpec.create( 994 user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg)) 995 ) 996 elif spec.kind == ep.OutputKind.LOSS_OUTPUT: 997 assert isinstance(spec.arg, ep.TensorArgument) 998 return OutputSpec.create( 999 loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name)) 1000 ) 1001 elif spec.kind == ep.OutputKind.BUFFER_MUTATION: 1002 assert spec.target is not None 1003 assert isinstance(spec.arg, ep.TensorArgument) 1004 return OutputSpec.create( 1005 buffer_mutation=BufferMutationSpec( 1006 arg=TensorArgument(name=spec.arg.name), 1007 buffer_name=spec.target, 1008 ) 1009 ) 1010 elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: 1011 assert spec.target is not None 1012 assert isinstance(spec.arg, ep.TensorArgument) 1013 return OutputSpec.create( 1014 gradient_to_parameter=GradientToParameterSpec( 1015 arg=TensorArgument(name=spec.arg.name), 1016 parameter_name=spec.target, 1017 ) 1018 ) 1019 elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: 1020 assert spec.target is not None 1021 assert isinstance(spec.arg, ep.TensorArgument) 1022 return OutputSpec.create( 1023 gradient_to_user_input=GradientToUserInputSpec( 1024 arg=TensorArgument(name=spec.arg.name), 1025 user_input_name=spec.target, 1026 ) 1027 ) 1028 elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: 1029 assert spec.target is not None 1030 assert isinstance(spec.arg, ep.TensorArgument) 1031 return OutputSpec.create( 1032 user_input_mutation=UserInputMutationSpec( 1033 arg=TensorArgument(name=spec.arg.name), 1034 user_input_name=spec.target, 1035 ) 1036 ) 1037 elif spec.kind == ep.OutputKind.TOKEN: 1038 assert isinstance(spec.arg, ep.TokenArgument) 1039 return OutputSpec.create( 1040 token=OutputTokenSpec( 1041 arg=TokenArgument(name=spec.arg.name), 1042 ) 1043 ) 1044 else: 1045 raise AssertionError(f"Unknown argument kind: {spec}") 1046 1047 def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: 1048 return GraphSignature( 1049 input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], 1050 output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], 1051 ) 1052 1053 def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: 1054 if isinstance(x, ep.TensorArgument): 1055 return Argument.create(as_tensor=TensorArgument(name=x.name)) 1056 elif isinstance(x, ep.SymIntArgument): 1057 return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) 1058 elif isinstance(x, ep.ConstantArgument): 1059 return self.serialize_input(x.value) 1060 elif isinstance(x, ep.CustomObjArgument): 1061 return Argument.create( 1062 as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn) 1063 ) 1064 else: 1065 raise AssertionError("TODO") 1066 1067 def serialize_module_call_signature( 1068 self, module_call_signature: ep.ModuleCallSignature 1069 ) -> ModuleCallSignature: 1070 return ModuleCallSignature( 1071 inputs=[ 1072 self.serialize_argument_spec(x) for x in module_call_signature.inputs 1073 ], 1074 outputs=[ 1075 self.serialize_argument_spec(x) for x in module_call_signature.outputs 1076 ], 1077 in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION), 1078 out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION), 1079 ) 1080 1081 def serialize_module_call_graph( 1082 self, module_call_graph: List[ep.ModuleCallEntry] 1083 ) -> List[ModuleCallEntry]: 1084 return [ 1085 ModuleCallEntry( 1086 fqn=entry.fqn, 1087 signature=( 1088 self.serialize_module_call_signature(entry.signature) 1089 if entry.signature 1090 else None 1091 ), 1092 ) 1093 for entry in module_call_graph 1094 ] 1095 1096 def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: 1097 """For a given node, return the dataclass representing its output values. 1098 1099 [NOTE: Multiple outputs] We handle aggregates differently than FX. For 1100 FX, it looks like: 1101 1102 x = call_function("multiple_return", ...) 1103 element0 = call_function(getitem, x, 0) 1104 foo = call_function("use_output", element0) 1105 1106 We do not want the intermediate `getitem` call, so our serialized thing looks like: 1107 1108 element0, element1, element2 = call_function("multiple_return", ...) 1109 foo = call_function("use_output", element0) 1110 1111 We want names to be consistent across these two schemes, so that we can 1112 mostly reuse the names coming from FX. This function computes a mapping from 1113 the FX representation to our representation, preserving the names. 1114 """ 1115 assert node.op == "call_function" and isinstance( 1116 node.target, torch._ops.OpOverload 1117 ) 1118 1119 assert isinstance(node.target, torch._ops.OpOverload) 1120 returns = node.target._schema.returns 1121 1122 if len(returns) == 0: 1123 return [] 1124 1125 meta_val = node.meta["val"] 1126 1127 # Check single value return 1128 if _is_single_tensor_list_return(node.target): 1129 # e.g "-> Tensor[]" 1130 tensor_args = [] 1131 for idx, meta in enumerate(meta_val): 1132 user_node = _output_node_at_index(node, idx) 1133 name = ( 1134 user_node.name 1135 if user_node is not None 1136 else f"{node.name}_unused_{idx}" 1137 ) 1138 tensor_args.append(self.serialize_tensor_output(name, meta)) 1139 return [Argument.create(as_tensors=tensor_args)] 1140 elif len(returns) == 1: 1141 return [self.serialize_output(node.name, meta_val)] 1142 1143 # There are a two possibilities at this point: 1144 # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)" 1145 # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])" 1146 # 1147 # Either way, start by gathering a list of TensorArguments with the correct names. 1148 # For consistent naming with FX, consult the downstream `getitem` node and 1149 # make sure our outputs have the same name. 1150 1151 output_arguments = [] 1152 for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): 1153 if meta is None: 1154 assert isinstance( 1155 return_schema.real_type, (torch.OptionalType, torch.TensorType) 1156 ) 1157 # When the return type is annoated as Tensor type, the op can also return an 1158 # undefined Tensor which will be implicitly converted to None in Python. 1159 output_arguments.append(Argument.create(as_none=())) 1160 elif isinstance(meta, FakeTensor): 1161 assert isinstance( 1162 return_schema.real_type, (torch.OptionalType, torch.TensorType) 1163 ) 1164 user_node = _output_node_at_index(node, idx) 1165 name = ( 1166 user_node.name 1167 if user_node is not None 1168 else f"{node.name}_unused_{idx}" 1169 ) 1170 output_arguments.append(self.serialize_output(name, meta)) 1171 elif isinstance(meta, list): 1172 # for List[Tensor] return type 1173 assert isinstance( 1174 return_schema.real_type, torch.ListType 1175 ) and isinstance( 1176 return_schema.real_type.getElementType(), torch.TensorType 1177 ) 1178 user_node = _output_node_at_index(node, idx) 1179 assert user_node is not None 1180 1181 args = [] 1182 for i, m in enumerate(meta): 1183 if m is None: 1184 continue 1185 sub_user_node = _output_node_at_index(user_node, i) 1186 assert sub_user_node is not None, f"No user found at index {i}" 1187 1188 args.append(self.serialize_tensor_output(sub_user_node.name, m)) 1189 output_arguments.append(Argument.create(as_tensors=args)) 1190 elif isinstance(meta, (int, SymInt)): 1191 user_node = _output_node_at_index(node, idx) 1192 name = ( 1193 user_node.name 1194 if user_node is not None 1195 else f"{node.name}_unused_{idx}" 1196 ) 1197 output_arguments.append(self.serialize_output(name, meta)) 1198 else: 1199 raise ValueError( 1200 f"Unhandled output type {type(meta)} from node {node.format_node()}" 1201 ) 1202 1203 return output_arguments 1204 1205 def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: 1206 """ 1207 For serializing HOO outputs since HOOs do not have a schema. 1208 """ 1209 meta_val = node.meta["val"] 1210 1211 if isinstance(meta_val, tuple): 1212 # Note: Since we don't have a schema, we just serialize all tuple 1213 # outputs to be a list of values. Even if the output is supposed to 1214 # be a tensor list (Tensor[]), we will serialize it to be a list of 1215 # tensors (Tensor, Tensor, Tensor). An exception is that if there's 1216 # a singleton tensor, we will serialize this to be a singleton 1217 # tensor list so that the deserializer knows to insert getitem nodes. 1218 1219 if len(meta_val) == 1: 1220 assert isinstance(meta_val[0], torch.Tensor) 1221 user_node = _output_node_at_index(node, 0) 1222 name = ( 1223 user_node.name if user_node is not None else f"{node.name}_unused_0" 1224 ) 1225 return [ 1226 Argument.create( 1227 as_tensors=[self.serialize_tensor_output(name, meta_val[0])] 1228 ) 1229 ] 1230 1231 outputs = [] 1232 for i, element_meta_val in enumerate(meta_val): 1233 user_node = _output_node_at_index(node, i) 1234 if isinstance(element_meta_val, list): 1235 # e.g "-> Tensor[]" 1236 assert user_node is not None 1237 1238 tensors = [] 1239 for j, m in enumerate(element_meta_val): 1240 if not isinstance(m, torch.Tensor): 1241 raise SerializeError( 1242 f"Serialize list output with type {type(m)} nyi" 1243 ) 1244 1245 sub_user_node = _output_node_at_index(user_node, j) 1246 name = ( 1247 sub_user_node.name 1248 if sub_user_node is not None 1249 else f"{user_node.name}_unused_{j}" 1250 ) 1251 tensors.append(self.serialize_tensor_output(name, m)) 1252 outputs.append(Argument.create(as_tensors=tensors)) 1253 1254 else: 1255 name = ( 1256 user_node.name 1257 if user_node is not None 1258 else f"{node.name}_unused_{i}" 1259 ) 1260 1261 outputs.append(self.serialize_output(name, element_meta_val)) 1262 1263 return outputs 1264 else: 1265 return [self.serialize_output(node.name, meta_val)] 1266 1267 def serialize_output(self, name: str, meta_val: Any) -> Argument: 1268 # Check single value return 1269 if meta_val is None: 1270 return Argument.create(as_none=()) 1271 if isinstance(meta_val, torch.Tensor): 1272 # e.g "-> Tensor" 1273 return Argument.create( 1274 as_tensor=self.serialize_tensor_output(name, meta_val) 1275 ) 1276 elif isinstance(meta_val, (int, torch.SymInt)): 1277 # e.g "-> SymInt" 1278 return Argument.create( 1279 as_sym_int=self.serialize_sym_int_output(name, meta_val) 1280 ) 1281 elif isinstance(meta_val, torch.SymBool): 1282 # e.g "-> SymBool" 1283 return Argument.create( 1284 as_sym_bool=self.serialize_sym_bool_output(name, meta_val) 1285 ) 1286 1287 # list outputs should've been handled earlier 1288 raise SerializeError(f"Unable to serialize output {meta_val}") 1289 1290 def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: 1291 meta_val = node.meta["val"] 1292 1293 idx_to_name = {} 1294 for user in node.users: 1295 assert ( 1296 user.target is operator.getitem 1297 ), f"User node {user} of {node} is incorrect" 1298 idx_to_name[user.args[1]] = user.name 1299 1300 for idx, _ in enumerate(meta_val): 1301 # FX does not emit a getitem node for any outputs that are unused. 1302 # However, we need a name for them so that the number of outputs will 1303 # correctly match the schema. Just assign a dummy name. 1304 if idx not in idx_to_name: 1305 idx_to_name[idx] = f"{node.name}_unused_{idx}" 1306 1307 arg_list = [] 1308 for i, element_meta_val in enumerate(meta_val): 1309 arg_list.append( 1310 self.serialize_tensor_output(idx_to_name[i], element_meta_val) 1311 ) 1312 1313 return arg_list 1314 1315 def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: 1316 assert isinstance(graph_module, torch.fx.GraphModule) 1317 for node in graph_module.graph.nodes: 1318 try: 1319 getattr(self, f"handle_{node.op}")(node) 1320 except Exception as e: 1321 raise SerializeError( 1322 f"Failed serializing node {node} in graph: {node.format_node()}" 1323 ) from e 1324 1325 return Graph( 1326 inputs=self.graph_state.inputs, 1327 nodes=self.graph_state.nodes, 1328 tensor_values=self.graph_state.tensor_values, 1329 sym_int_values=self.graph_state.sym_int_values, 1330 sym_bool_values=self.graph_state.sym_bool_values, 1331 custom_obj_values=self.graph_state.custom_obj_values, 1332 outputs=self.graph_state.outputs, 1333 is_single_tensor_return=self.graph_state.is_single_tensor_return, 1334 ) 1335 1336 def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule: 1337 graph = self.serialize_graph(graph_module) 1338 1339 return GraphModule( 1340 graph=graph, 1341 signature=self.serialize_signature(self.graph_signature), 1342 module_call_graph=self.serialize_module_call_graph(self.module_call_graph), 1343 ) 1344 1345 1346class ExportedProgramSerializer: 1347 def __init__(self, opset_version: Optional[Dict[str, int]] = None): 1348 self.opset_version: Dict[str, int] = {} 1349 if opset_version: 1350 self.opset_version.update(opset_version) 1351 if "aten" not in self.opset_version: 1352 self.opset_version["aten"] = torch._C._get_max_operator_version() 1353 1354 def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: 1355 """ 1356 Args: 1357 exported_program: Exported Program to serialize 1358 """ 1359 exported_program._validate() 1360 1361 gm_serializer = GraphModuleSerializer( 1362 exported_program.graph_signature, exported_program.module_call_graph 1363 ) 1364 serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) 1365 serialized_range_constraints = serialize_range_constraints( 1366 exported_program.range_constraints 1367 ) 1368 1369 # TODO: Directly serialize exported_program.constants once 1370 # CustomClassHolders get stored in the ExportedProgram rather than in 1371 # the graph 1372 constants = {} 1373 for n, c in gm_serializer.custom_objs.items(): 1374 constants[n] = c 1375 for n, t in exported_program.constants.items(): 1376 assert n not in constants 1377 constants[n] = t 1378 1379 additional_kwargs = {} 1380 if hasattr(exported_program, "verifiers"): 1381 additional_kwargs["verifiers"] = [ 1382 v.dialect for v in exported_program.verifiers 1383 ] 1384 elif hasattr(exported_program, "dialect"): 1385 additional_kwargs["dialect"] = exported_program.dialect 1386 serialized_ep = ExportedProgram( 1387 graph_module=serialized_graph_module, 1388 opset_version=self.opset_version, 1389 range_constraints=serialized_range_constraints, 1390 schema_version=SchemaVersion( 1391 major=SCHEMA_VERSION[0], 1392 minor=SCHEMA_VERSION[1], 1393 ), 1394 **additional_kwargs, 1395 ) 1396 1397 # Test canonical form is well defined. 1398 canonicalize(serialized_ep) 1399 1400 return _SerializedProgram( 1401 serialized_ep, 1402 serialize_torch_artifact(exported_program.state_dict), 1403 serialize_torch_artifact(constants), 1404 serialize_torch_artifact(exported_program.example_inputs), 1405 ) 1406 1407 1408class GraphModuleDeserializer: 1409 @dataclasses.dataclass 1410 class Result: 1411 graph_module: torch.fx.GraphModule 1412 signature: ep.ExportGraphSignature 1413 module_call_graph: List[ep.ModuleCallEntry] 1414 names_to_symbols: Dict[str, sympy.Symbol] 1415 state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] 1416 constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]] 1417 example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]] 1418 1419 def __init__(self): 1420 self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} 1421 self.serialized_name_to_meta: Dict[str, MetaType] = {} 1422 self.graph = torch.fx.Graph() 1423 self.module = torch.nn.Module() 1424 1425 @contextmanager 1426 def save_graph_module(self) -> Iterator[None]: 1427 saved = ( 1428 self.graph, 1429 self.module, 1430 self.serialized_name_to_node, 1431 self.serialized_name_to_meta, 1432 ) 1433 self.graph = torch.fx.Graph() 1434 self.module = torch.nn.Module() 1435 self.serialized_name_to_node = {} 1436 self.serialized_name_to_meta = {} 1437 try: 1438 yield 1439 finally: 1440 ( 1441 self.graph, 1442 self.module, 1443 self.serialized_name_to_node, 1444 self.serialized_name_to_meta, 1445 ) = saved 1446 1447 def deserialize_operator(self, serialized_target: str): 1448 if serialized_target.startswith( 1449 "_operator" 1450 ): # TODO(zhxchen17) Follow up on this. 1451 module = operator 1452 serialized_target_names = serialized_target.split(".")[1:] 1453 elif serialized_target.startswith("torch"): 1454 module = torch # type: ignore[misc] 1455 serialized_target_names = serialized_target.split(".")[1:] 1456 else: # TODO(zhxchen17) Don't catch all here. 1457 return serialized_target 1458 1459 target = module 1460 for name in serialized_target_names: 1461 if not hasattr(target, name): 1462 return serialized_target 1463 else: 1464 target = getattr(target, name) 1465 return target 1466 1467 def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: 1468 val = s.value 1469 if s.type == "as_expr": 1470 if val.hint is None: 1471 hint = None 1472 else: 1473 assert val.hint.type == "as_int" 1474 hint = val.hint.value 1475 1476 if val.expr_str in self.symbol_name_to_symbol: 1477 sym = self.symbol_name_to_symbol[val.expr_str] 1478 if ( 1479 isinstance(sym, sympy.Symbol) 1480 and sym not in self.shape_env.var_to_val 1481 ): 1482 if hint is not None: 1483 self.shape_env.add_var_to_val(sym, hint) 1484 else: 1485 sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) 1486 # NOTE(avik): Assumptions on symbols are not explicitly serialized. 1487 # This seems dangerous: it might cause unknown differences in shape env behavior 1488 # on deserialization? Probably deserves a follow-up. 1489 1490 # Here we force symbols corresponding to SymInts to be at least integers. 1491 # Otherwise some expressions that the shape env would otherwise evaluate to False, 1492 # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. 1493 sym = sym.subs( 1494 {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} 1495 ) 1496 if isinstance(sym, sympy.Symbol): 1497 self.symbol_name_to_symbol[val.expr_str] = sym 1498 if hint is not None: 1499 self.shape_env.add_var_to_val(sym, hint) 1500 1501 if vr := self.symbol_name_to_range.get(val.expr_str): 1502 self.shape_env.constrain_symbol_range( 1503 sym, 1504 compiler_min=vr.lower, # type: ignore[arg-type] 1505 compiler_max=vr.upper, # type: ignore[arg-type] 1506 ) 1507 else: 1508 # Placeholders, in particular, can have shapes as symbolic expressions. 1509 # We need to populate the shape env with the range constraints of their 1510 # free symbols, otherwise evaluating such expressions will error. 1511 self.symbol_name_to_symbol[val.expr_str] = sym 1512 free_symbols = sym.free_symbols 1513 for s in free_symbols: 1514 if s.name not in self.symbol_name_to_symbol: 1515 self.symbol_name_to_symbol[s.name] = s 1516 if vr := self.symbol_name_to_range.get(s.name): 1517 self.shape_env.constrain_symbol_range( 1518 s, 1519 compiler_min=vr.lower, # type: ignore[arg-type] 1520 compiler_max=vr.upper, # type: ignore[arg-type] 1521 ) 1522 1523 return self.shape_env.create_symintnode(sym, hint=hint) 1524 elif s.type == "as_int": 1525 assert isinstance(val, int) 1526 return val 1527 else: 1528 raise SerializeError( 1529 f"SymInt has invalid field type {s.type} with value {s.value}" 1530 ) 1531 1532 def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]: 1533 val = s.value 1534 if s.type == "as_expr": 1535 expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) 1536 return self.shape_env.create_symboolnode(expr) 1537 elif s.type == "as_bool": 1538 assert isinstance(val, bool) 1539 return val 1540 else: 1541 raise SerializeError( 1542 f"SymBool has invalid field type {s.type} with value {s.value}" 1543 ) 1544 1545 def deserialize_tensor_meta( 1546 self, 1547 tensor_meta: TensorMeta, 1548 ) -> FakeTensor: 1549 with self.fake_tensor_mode: 1550 return cast( 1551 FakeTensor, 1552 torch.empty_strided( 1553 tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc] 1554 tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc] 1555 device=deserialize_device(tensor_meta.device), 1556 dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], 1557 ), 1558 ) 1559 1560 def deserialize_script_obj_meta( 1561 self, script_obj_meta: CustomObjArgument 1562 ) -> ep.CustomObjArgument: 1563 return ep.CustomObjArgument( 1564 name=script_obj_meta.name, 1565 class_fqn=script_obj_meta.class_fqn, 1566 ) 1567 1568 def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]: 1569 if output.type == "as_tensor": 1570 return self.serialized_name_to_node[output.as_tensor.name] 1571 elif output.type == "as_sym_int": 1572 return self.serialized_name_to_node[output.as_sym_int.as_name] 1573 elif output.type == "as_sym_bool": 1574 return self.serialized_name_to_node[output.as_sym_bool.as_name] 1575 elif output.type == "as_int": 1576 return output.as_int 1577 elif output.type == "as_none": 1578 return None 1579 else: 1580 raise SerializeError(f"Unable to deserialize output node {output}") 1581 1582 def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: 1583 # Handle the tensor metas. 1584 for name, tensor_value in serialized_graph.tensor_values.items(): 1585 meta_val = self.deserialize_tensor_meta(tensor_value) 1586 self.serialized_name_to_meta[name] = meta_val 1587 1588 for name, sym_int_value in serialized_graph.sym_int_values.items(): 1589 self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value) 1590 1591 for name, sym_bool_value in serialized_graph.sym_bool_values.items(): 1592 self.serialized_name_to_meta[name] = self.deserialize_sym_bool( 1593 sym_bool_value 1594 ) 1595 1596 for name, script_obj_meta in serialized_graph.custom_obj_values.items(): 1597 self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta( 1598 script_obj_meta 1599 ) 1600 1601 # Inputs: convert to placeholder nodes in FX. 1602 for i, input_ in enumerate(serialized_graph.inputs): 1603 if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"): 1604 if input_.type == "as_sym_int": 1605 node_name = input_.value.as_name 1606 else: 1607 node_name = input_.value.name 1608 placeholder_node = self.graph.placeholder(node_name) 1609 # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments) 1610 # we will overwrite it 1611 placeholder_node.name = node_name 1612 self.sync_fx_node(node_name, placeholder_node) 1613 elif input_.type in ( 1614 "as_int", 1615 "as_float", 1616 "as_bool", 1617 "as_none", 1618 "as_string", 1619 ): 1620 node_name = self.signature.input_specs[i].arg.name 1621 placeholder_node = self.graph.placeholder(node_name) 1622 placeholder_node.meta["val"] = self.deserialize_input(input_) 1623 else: 1624 raise SerializeError(f"Invalid input type {input_}") 1625 1626 # Nodes: convert to call_function nodes. 1627 for serialized_node in serialized_graph.nodes: 1628 try: 1629 target = self.deserialize_operator(serialized_node.target) 1630 self.deserialize_node(serialized_node, target) 1631 1632 except Exception as e: 1633 raise SerializeError( 1634 f"Failed deserializing node {serialized_node}" 1635 ) from e 1636 1637 # Outputs: convert to a single `output` node. 1638 outputs = [] 1639 for output in serialized_graph.outputs: 1640 outputs.append(self.deserialize_graph_output(output)) 1641 1642 if serialized_graph.is_single_tensor_return: 1643 assert len(outputs) == 1 1644 outputs = outputs[0] # type: ignore[assignment] 1645 else: 1646 outputs = tuple(outputs) # type: ignore[assignment] 1647 1648 output_node = self.graph.output(outputs) 1649 1650 if serialized_graph.is_single_tensor_return: 1651 output_node.meta["val"] = output_node.args[0].meta["val"] 1652 else: 1653 output_node.meta["val"] = tuple( 1654 arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg 1655 for arg in output_node.args[0] 1656 ) 1657 1658 return self.graph 1659 1660 def deserialize_node(self, serialized_node: Node, target: Callable) -> None: 1661 if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS: 1662 name = serialized_node.outputs[0].value.as_name 1663 args = self.deserialize_sym_op_inputs(serialized_node.inputs) 1664 1665 fx_node = self.graph.create_node("call_function", target, args, {}, name) 1666 self.deserialize_sym_op_outputs(serialized_node, fx_node) 1667 1668 elif isinstance(target, torch._ops.HigherOrderOperator): 1669 args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) 1670 # If HOP returns a single tensor, name the 1671 # newly-created node after it. This ensures that these tensor values 1672 # have names that are consistent with serialized. 1673 # 1674 # HOPs don't have schema yet, just check the output lengths and as_tensor attribute 1675 name = ( 1676 serialized_node.outputs[0].as_tensor.name 1677 if len(serialized_node.outputs) == 1 1678 and hasattr(serialized_node.outputs[0], "as_tensor") 1679 else None 1680 ) 1681 fx_node = self.graph.create_node( 1682 "call_function", target, args, kwargs, name 1683 ) 1684 self.deserialize_outputs(serialized_node, fx_node) 1685 fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) 1686 1687 elif isinstance(target, torch._ops.OpOverload): 1688 # For convenience: if this node returns a single tensor, name the 1689 # newly-created node after it. This ensures that these tensor values 1690 # have names that are consistent with serialized. 1691 name = ( 1692 serialized_node.outputs[0].as_tensor.name 1693 if _is_single_tensor_return(target) 1694 else None # FX will generate a name for us. 1695 ) 1696 args, kwargs = self.deserialize_inputs(target, serialized_node) 1697 fx_node = self.graph.create_node( 1698 "call_function", target, args, kwargs, name 1699 ) 1700 self.deserialize_outputs(serialized_node, fx_node) 1701 else: 1702 raise SerializeError( 1703 f"Unsupported target type for node {serialized_node}: {target}" 1704 ) 1705 1706 fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) 1707 if ( 1708 fx_node.op not in ["placeholder", "output"] 1709 and "nn_module_stack" not in fx_node.meta 1710 ): 1711 fx_node.meta["nn_module_stack"] = ( 1712 {} 1713 ) # serialization throws away empty dicts 1714 1715 def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: 1716 if i.type == "user_input": 1717 return ep.InputSpec( 1718 kind=ep.InputKind.USER_INPUT, 1719 arg=self.deserialize_argument_spec(i.user_input.arg), 1720 target=None, 1721 ) 1722 elif i.type == "parameter": 1723 return ep.InputSpec( 1724 kind=ep.InputKind.PARAMETER, 1725 arg=ep.TensorArgument(name=i.parameter.arg.name), 1726 target=i.parameter.parameter_name, 1727 ) 1728 elif i.type == "buffer": 1729 return ep.InputSpec( 1730 kind=ep.InputKind.BUFFER, 1731 arg=ep.TensorArgument(name=i.buffer.arg.name), 1732 target=i.buffer.buffer_name, 1733 persistent=i.buffer.persistent, 1734 ) 1735 elif i.type == "tensor_constant": 1736 return ep.InputSpec( 1737 kind=ep.InputKind.CONSTANT_TENSOR, 1738 arg=ep.TensorArgument(name=i.tensor_constant.arg.name), 1739 target=i.tensor_constant.tensor_constant_name, 1740 ) 1741 elif i.type == "custom_obj": 1742 return ep.InputSpec( 1743 kind=ep.InputKind.CUSTOM_OBJ, 1744 arg=ep.CustomObjArgument( 1745 name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn 1746 ), 1747 target=i.custom_obj.custom_obj_name, 1748 ) 1749 elif i.type == "token": 1750 return ep.InputSpec( 1751 kind=ep.InputKind.TOKEN, 1752 arg=ep.TokenArgument(name=i.token.arg.name), 1753 target=None, 1754 ) 1755 elif i.type == "constant_input": 1756 return ep.InputSpec( 1757 kind=ep.InputKind.USER_INPUT, 1758 arg=ep.ConstantArgument( 1759 name=i.constant_input.name, 1760 value=self.deserialize_constant_input(i.constant_input.value), 1761 ), 1762 target=None, 1763 ) 1764 else: 1765 raise AssertionError(f"Unknown input spec {i}") 1766 1767 def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: 1768 if o.type == "user_output": 1769 return ep.OutputSpec( 1770 kind=ep.OutputKind.USER_OUTPUT, 1771 arg=self.deserialize_argument_spec(o.user_output.arg), 1772 target=None, 1773 ) 1774 elif o.type == "loss_output": 1775 return ep.OutputSpec( 1776 kind=ep.OutputKind.LOSS_OUTPUT, 1777 arg=ep.TensorArgument(name=o.loss_output.arg.name), 1778 target=None, 1779 ) 1780 elif o.type == "buffer_mutation": 1781 return ep.OutputSpec( 1782 kind=ep.OutputKind.BUFFER_MUTATION, 1783 arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), 1784 target=o.buffer_mutation.buffer_name, 1785 ) 1786 elif o.type == "gradient_to_parameter": 1787 return ep.OutputSpec( 1788 kind=ep.OutputKind.GRADIENT_TO_PARAMETER, 1789 arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), 1790 target=o.gradient_to_parameter.parameter_name, 1791 ) 1792 elif o.type == "gradient_to_user_input": 1793 return ep.OutputSpec( 1794 kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, 1795 arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), 1796 target=o.gradient_to_user_input.user_input_name, 1797 ) 1798 elif o.type == "user_input_mutation": 1799 return ep.OutputSpec( 1800 kind=ep.OutputKind.USER_INPUT_MUTATION, 1801 arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), 1802 target=o.user_input_mutation.user_input_name, 1803 ) 1804 elif o.type == "token": 1805 return ep.OutputSpec( 1806 kind=ep.OutputKind.TOKEN, 1807 arg=ep.TokenArgument(name=o.token.arg.name), 1808 target=None, 1809 ) 1810 else: 1811 raise AssertionError(f"Unknown output spec {o}") 1812 1813 def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: 1814 return ep.ExportGraphSignature( 1815 input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], 1816 output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs], 1817 ) 1818 1819 def deserialize( 1820 self, 1821 serialized_graph_module: GraphModule, 1822 serialized_state_dict: Union[Dict[str, torch.Tensor], bytes], 1823 constants: Union[Dict[str, Any], bytes], 1824 example_inputs: Optional[ 1825 Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes] 1826 ] = None, 1827 symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, 1828 ) -> Result: 1829 global _CURRENT_DESERIALIZER 1830 current_deserializer_state = _CURRENT_DESERIALIZER.copy() 1831 _CURRENT_DESERIALIZER.append(self) 1832 try: 1833 self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) 1834 self.fake_tensor_mode = FakeTensorMode( 1835 allow_fallback_kernels=False, 1836 allow_non_fake_inputs=True, 1837 shape_env=self.shape_env, 1838 ) 1839 self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} 1840 self.constants = deserialize_torch_artifact(constants) 1841 self.signature = self.deserialize_signature( 1842 serialized_graph_module.signature 1843 ) 1844 1845 # deserialization does analysis with checks on 0/1, so we create fake range constraints and 1846 # restore the original range constraints afterwards 1847 self.symbol_name_to_range = {} 1848 if symbol_name_to_range: 1849 for k, vr in symbol_name_to_range.items(): 1850 if math.isinf(vr.lower) and vr.lower < 0: 1851 lower = -math.inf 1852 elif math.isinf(vr.lower): 1853 lower = math.inf 1854 else: 1855 lower = int(vr.lower) 1856 1857 if vr.upper >= 2: # max is >= 2, not sym bool range 1858 lower = max(2, lower) 1859 self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( 1860 _int_to_sympy_int(lower), vr.upper 1861 ) 1862 1863 if example_inputs is not None and len(example_inputs) > 0: 1864 self.example_inputs = deserialize_torch_artifact(example_inputs) 1865 else: 1866 self.example_inputs = None 1867 self.deserialize_graph(serialized_graph_module.graph) 1868 1869 module_call_graph = self.deserialize_module_call_graph( 1870 serialized_graph_module.module_call_graph 1871 ) 1872 return GraphModuleDeserializer.Result( 1873 graph_module=ep._create_graph_module_for_export( 1874 self.module, self.graph 1875 ), 1876 signature=self.signature, 1877 module_call_graph=module_call_graph, 1878 names_to_symbols=self.symbol_name_to_symbol, 1879 state_dict=deserialize_torch_artifact(serialized_state_dict), 1880 constants=self.constants, 1881 example_inputs=self.example_inputs, 1882 ) 1883 finally: 1884 _CURRENT_DESERIALIZER.pop() 1885 assert current_deserializer_state == _CURRENT_DESERIALIZER 1886 1887 def sync_fx_node(self, name: str, fx_node: torch.fx.Node): 1888 if name in self.serialized_name_to_node: 1889 raise SerializeError(f"Node {name} has already been deserialized before.") 1890 self.serialized_name_to_node[name] = fx_node 1891 assert "val" not in fx_node.meta 1892 fx_node.meta["val"] = self.serialized_name_to_meta[name] 1893 1894 def deserialize_sym_op_inputs(self, inputs): 1895 return tuple(self.deserialize_input(input.arg) for input in inputs) 1896 1897 def deserialize_inputs(self, target: torch._ops.OpOverload, serialized_node: Node): 1898 schema_args = target._schema.arguments 1899 actual_args = { 1900 input.name: self.deserialize_input(input.arg) 1901 for input in serialized_node.inputs 1902 } 1903 args = [] 1904 kwargs = {} 1905 for schema_arg in schema_args: 1906 is_positional = ( 1907 not schema_arg.has_default_value() and not schema_arg.kwarg_only 1908 ) 1909 if is_positional: 1910 args.append(actual_args[schema_arg.name]) 1911 else: 1912 if schema_arg.name in actual_args: 1913 kwargs[schema_arg.name] = actual_args[schema_arg.name] 1914 return tuple(args), kwargs 1915 1916 def deserialize_hoo_inputs(self, inputs: List[NamedArgument]): 1917 """ 1918 For deserializing HOO inputs since HOOs do not have a schema. 1919 """ 1920 args = [] 1921 kwargs = {} 1922 for input_ in inputs: 1923 if input_.name != "": 1924 kwargs[input_.name] = self.deserialize_input(input_.arg) 1925 else: 1926 args.append(self.deserialize_input(input_.arg)) 1927 return (tuple(args), kwargs) 1928 1929 def deserialize_input(self, inp: Argument) -> Any: 1930 value = inp.value 1931 typ_ = inp.type 1932 if typ_ == "as_none": 1933 # None should converted as None, but is encoded as bool in serialized 1934 # Convert serialized object to torch equivalent 1935 return None 1936 elif typ_ == "as_tensor": 1937 return self.serialized_name_to_node[inp.as_tensor.name] 1938 elif typ_ == "as_scalar_type": 1939 return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] 1940 elif typ_ == "as_memory_format": 1941 return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] 1942 elif typ_ == "as_layout": 1943 return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] 1944 elif typ_ == "as_graph": 1945 assert isinstance(value, GraphArgument) 1946 with self.save_graph_module(): 1947 self.deserialize_graph(value.graph) 1948 submodule = ep._create_graph_module_for_export(self.module, self.graph) 1949 self.module.register_module(value.name, submodule) 1950 return self.graph.create_node( 1951 "get_attr", 1952 value.name, 1953 name=value.name, 1954 ) 1955 elif typ_ == "as_device": 1956 return deserialize_device(inp.as_device) 1957 elif typ_ == "as_int": 1958 return inp.as_int 1959 elif typ_ == "as_float": 1960 return inp.as_float 1961 elif typ_ == "as_bool": 1962 return inp.as_bool 1963 elif typ_ == "as_string": 1964 return inp.as_string 1965 elif typ_ == "as_sym_int": 1966 return self.deserialize_sym_argument(inp.as_sym_int) 1967 elif typ_ == "as_sym_bool": 1968 return self.deserialize_sym_argument(inp.as_sym_bool) 1969 elif isinstance(value, list): 1970 if len(value) == 0: 1971 return [] 1972 elif typ_ == "as_tensors": 1973 result = [] 1974 for arg in value: 1975 result.append(self.serialized_name_to_node[arg.name]) 1976 return result 1977 elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): 1978 # convert from serialized.python.types.List to python list 1979 return list(value) 1980 elif typ_ in ("as_sym_ints", "as_sym_bools"): 1981 return [self.deserialize_sym_argument(arg) for arg in value] 1982 elif typ_ == "as_optional_tensors": 1983 1984 def deserialize_optional_tensor_args(a): 1985 if a.type == "as_none": 1986 return None 1987 elif a.type == "as_tensor": 1988 return self.serialized_name_to_node[a.value.name] 1989 else: 1990 raise SerializeError(f"Unhandled argument {inp}") 1991 1992 return list(map(deserialize_optional_tensor_args, value)) 1993 else: 1994 raise SerializeError(f"Unhandled argument {inp}") 1995 elif typ_ == "as_custom_obj": 1996 if inp.as_custom_obj.name in self.serialized_name_to_node: 1997 # Custom object has been lifted as an input 1998 return self.serialized_name_to_node[inp.as_custom_obj.name] 1999 return self.constants[inp.as_custom_obj.name] 2000 elif typ_ == "as_operator": 2001 return self.deserialize_operator(inp.as_operator) 2002 else: 2003 raise SerializeError(f"Unhandled argument {inp}") 2004 2005 def deserialize_constant_input(self, inp: ConstantValue) -> Any: 2006 if inp.type == "as_int": 2007 return int(inp.as_int) 2008 elif inp.type == "as_float": 2009 return float(inp.as_float) 2010 elif inp.type == "as_string": 2011 return str(inp.as_string) 2012 elif inp.type == "as_bool": 2013 return bool(inp.as_bool) 2014 elif inp.type == "as_none": 2015 return None 2016 else: 2017 raise SerializeError(f"Unhandled constant argument {inp} to deserialize") 2018 2019 def deserialize_sym_argument(self, sym_arg): 2020 if isinstance(sym_arg, SymIntArgument): 2021 if sym_arg.type == "as_int": 2022 return sym_arg.as_int 2023 elif sym_arg.type == "as_name": 2024 return self.serialized_name_to_node[sym_arg.as_name] 2025 elif isinstance(sym_arg, SymBoolArgument): 2026 if sym_arg.type == "as_bool": 2027 return sym_arg.as_bool 2028 elif sym_arg.type == "as_name": 2029 return self.serialized_name_to_node[sym_arg.as_name] 2030 raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") 2031 2032 def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): 2033 self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) 2034 2035 def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): 2036 # Check single value return 2037 if len(serialized_node.outputs) == 0: 2038 return 2039 if ( 2040 len(serialized_node.outputs) == 1 2041 and serialized_node.outputs[0].type == "as_tensor" 2042 ): 2043 self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) 2044 return 2045 elif len(serialized_node.outputs) == 1 and isinstance( 2046 serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument) 2047 ): 2048 self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) 2049 return 2050 2051 self.deserialize_multiple_outputs(serialized_node, fx_node) 2052 2053 def deserialize_multiple_outputs( 2054 self, serialized_node: Node, fx_node: torch.fx.Node 2055 ) -> None: 2056 deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) 2057 2058 def generate_getitem( 2059 meta_val, 2060 fx_node: torch.fx.Node, 2061 arg: Union[TensorArgument, SymIntArgument], 2062 idx: int, 2063 ): 2064 if isinstance(arg, TensorArgument): 2065 name = arg.name 2066 elif isinstance(arg, SymIntArgument): 2067 name = arg.as_name 2068 else: 2069 raise AssertionError( 2070 f"generate_getitem got unknown argument type {type(arg)}" 2071 ) 2072 individual_output = self.graph.create_node( 2073 "call_function", 2074 operator.getitem, 2075 (fx_node, idx), 2076 name=name, 2077 ) 2078 self.sync_fx_node(name, individual_output) 2079 meta_val.append(self.serialized_name_to_meta[name]) 2080 # The derived `getitem` nodes should have the same stacktrace as the 2081 # original `fx_node` 2082 individual_output.meta.update(deserialized_metadata) 2083 2084 def generate_getitems(meta_val, fx_node: torch.fx.Node, args): 2085 for idx, arg in enumerate(args): 2086 if isinstance(arg, Argument): 2087 arg = arg.value 2088 if isinstance(arg, (TensorArgument, SymIntArgument)): 2089 generate_getitem(meta_val, fx_node, arg, idx) 2090 elif isinstance(arg, (list, tuple)): 2091 list_output = self.graph.create_node( 2092 "call_function", 2093 operator.getitem, 2094 (fx_node, idx), 2095 ) 2096 meta_val.append([]) 2097 generate_getitems(meta_val[-1], list_output, arg) 2098 list_output.meta.update(deserialized_metadata) 2099 list_output.meta["val"] = meta_val[-1] 2100 else: 2101 raise NotImplementedError(f"Unimplemented node output type: {arg}") 2102 2103 # Convert multiple return types to FX format. 2104 # In FX, each node only returns one value. So in order to represent 2105 # multiple return values, we have to emit a `getitem` node for each 2106 # return value. 2107 # This performs the inverse mapping of the `serialize_outputs` call in 2108 # serialization, see [NOTE: Multiple outputs] 2109 meta_val: List[Any] = [] 2110 if len(serialized_node.outputs) == 1: 2111 assert isinstance(serialized_node.outputs[0].value, list) 2112 assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) 2113 generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors) 2114 else: 2115 generate_getitems(meta_val, fx_node, serialized_node.outputs) 2116 2117 # also update the metaval for `fx_node` to be a list(meta) 2118 fx_node.meta["val"] = tuple(meta_val) 2119 self.serialized_name_to_node[fx_node.name] = fx_node 2120 2121 def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: 2122 ret: Dict[str, Any] = {} 2123 if stack_trace := metadata.get("stack_trace"): 2124 ret["stack_trace"] = stack_trace 2125 2126 def deserialize_meta_func(serialized_target: str): 2127 module = None 2128 if serialized_target.startswith("torch.nn"): 2129 module = torch.nn 2130 serialized_target_names = serialized_target.split(".")[2:] 2131 elif serialized_target.startswith("torch"): 2132 module = torch 2133 serialized_target_names = serialized_target.split(".")[1:] 2134 else: 2135 return self.deserialize_operator(serialized_target) 2136 2137 target = module 2138 for name in serialized_target_names: 2139 if not hasattr(target, name): 2140 return serialized_target 2141 else: 2142 target = getattr(target, name) 2143 return target 2144 2145 if nn_module_stack_str := metadata.get("nn_module_stack"): 2146 # Originally serialized to "key,orig_path,type_str" 2147 def import_nn_module_stack(key, path, ty): 2148 return key, (path, ty) 2149 2150 # Helper function that splits strings by commas except for those 2151 # encapsulated by parens, which are valid traces. 2152 # TODO: Currently this is needed due to indexing Sequential 2153 # layers introducing names in the form "layer.slice(1, None, None)". 2154 # If that naming is improved, this fancier splitting can probably be 2155 # reverted to a simple split by comma. 2156 def metadata_split(metadata): 2157 # Remove the parentheses and commas inside them 2158 metadata = re.sub(r"\(.*?\)", "", metadata) 2159 # Split the string by comma, except for those inside parentheses 2160 return re.split(r"(?<!\()\s*,\s*(?!\()", metadata) 2161 2162 nn_module_stack = dict( 2163 import_nn_module_stack(*metadata_split(item)) 2164 for item in nn_module_stack_str.split(ST_DELIMITER) 2165 ) 2166 ret["nn_module_stack"] = nn_module_stack 2167 2168 if source_fn_st_str := metadata.get("source_fn_stack"): 2169 # Originally serializes to "fx_node_name,op_str" 2170 source_fn_st = [] 2171 for source_fn_str in source_fn_st_str.split(ST_DELIMITER): 2172 name, target_str = source_fn_str.split(",") 2173 source_fn_st.append((name, deserialize_meta_func(target_str))) 2174 ret["source_fn_stack"] = source_fn_st 2175 2176 if torch_fn_str := metadata.get("torch_fn"): 2177 ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER)) 2178 return ret 2179 2180 def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: 2181 if x.type == "as_tensor": 2182 return ep.TensorArgument(name=x.as_tensor.name) 2183 elif x.type == "as_sym_int": 2184 return ep.SymIntArgument(name=x.as_sym_int.as_name) 2185 elif x.type == "as_custom_obj": 2186 return ep.ConstantArgument( 2187 name=x.as_custom_obj.name, value=self.deserialize_input(x) 2188 ) 2189 else: 2190 return ep.ConstantArgument(name="", value=self.deserialize_input(x)) 2191 2192 def deserialize_module_call_signature( 2193 self, module_call_signature: ModuleCallSignature 2194 ) -> ep.ModuleCallSignature: 2195 return ep.ModuleCallSignature( 2196 inputs=[ 2197 self.deserialize_argument_spec(x) for x in module_call_signature.inputs 2198 ], 2199 outputs=[ 2200 self.deserialize_argument_spec(x) for x in module_call_signature.outputs 2201 ], 2202 in_spec=treespec_loads(module_call_signature.in_spec), 2203 out_spec=treespec_loads(module_call_signature.out_spec), 2204 ) 2205 2206 def deserialize_module_call_graph( 2207 self, module_call_graph: List[ModuleCallEntry] 2208 ) -> List[ep.ModuleCallEntry]: 2209 return [ 2210 ep.ModuleCallEntry( 2211 fqn=entry.fqn, 2212 signature=( 2213 self.deserialize_module_call_signature(entry.signature) 2214 if entry.signature 2215 else None 2216 ), 2217 ) 2218 for entry in module_call_graph 2219 ] 2220 2221 2222class ExportedProgramDeserializer: 2223 def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): 2224 self.expected_opset_version: Dict[str, int] = {} 2225 if expected_opset_version: 2226 self.expected_opset_version.update(expected_opset_version) 2227 if "aten" not in self.expected_opset_version: 2228 self.expected_opset_version["aten"] = torch._C._get_max_operator_version() 2229 2230 def deserialize_range_constraints( 2231 self, 2232 symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges], 2233 symbol_name_to_symbol: Dict[str, sympy.Symbol], 2234 ) -> Dict[sympy.Symbol, ValueRanges]: 2235 range_constraints = {} 2236 for k, v in symbol_name_to_range.items(): 2237 if symbol := symbol_name_to_symbol.get(k): 2238 range_constraints[symbol] = v # type: ignore[arg-type] 2239 else: 2240 log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004 2241 return range_constraints 2242 2243 def deserialize( 2244 self, 2245 exported_program: ExportedProgram, 2246 state_dict: Union[Dict[str, torch.Tensor], bytes], 2247 constants: Union[Dict[str, torch.Tensor], bytes], 2248 example_inputs: Optional[ 2249 Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes] 2250 ] = None, 2251 ) -> ep.ExportedProgram: 2252 assert isinstance(exported_program, ExportedProgram) 2253 version = exported_program.schema_version 2254 2255 # TODO(zhxchen17) blocked on thrift schema refactor 2256 if version.major != SCHEMA_VERSION[0] and not ( 2257 version.major == 0 and version.minor == 0 2258 ): 2259 raise SerializeError( 2260 f"Serialized schema version {exported_program.schema_version} " 2261 f"does not match our current schema version {SCHEMA_VERSION}." 2262 ) 2263 2264 symbol_name_to_range = { 2265 k: symbolic_shapes.ValueRanges( 2266 _int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val) 2267 ) 2268 for k, v in exported_program.range_constraints.items() 2269 } 2270 res = GraphModuleDeserializer().deserialize( 2271 exported_program.graph_module, 2272 state_dict, 2273 constants, 2274 example_inputs, 2275 symbol_name_to_range, 2276 ) 2277 range_constraints = self.deserialize_range_constraints( 2278 symbol_name_to_range, 2279 res.names_to_symbols, 2280 ) 2281 model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version 2282 self._validate_model_opset_version(model_opset_version) 2283 2284 upgrader = GraphModuleOpUpgrader( 2285 self.expected_opset_version, model_opset_version 2286 ) 2287 2288 exported_program = ep.ExportedProgram( 2289 root=res.graph_module, 2290 graph=res.graph_module.graph, 2291 graph_signature=res.signature, 2292 state_dict=res.state_dict, # type: ignore[arg-type] 2293 range_constraints=range_constraints, 2294 module_call_graph=res.module_call_graph, 2295 example_inputs=res.example_inputs, 2296 verifier=load_verifier(exported_program.dialect), 2297 constants=res.constants, 2298 ) 2299 return upgrader.upgrade(exported_program) 2300 2301 def _validate_model_opset_version( 2302 self, model_opset_version: Optional[Dict[str, int]] 2303 ): 2304 """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version 2305 difference. 2306 E.g., model_opset_version = {"aten": 3, "custom": 4} 2307 expected_opset_version = {"aten": 4, "custom": 4} 2308 This means we can use an upgrader for ATen to reconcile the deserialized model. 2309 2310 The logic of this method: 2311 2312 For common op namespaces: 2313 1. if model version < expected version, this case can be handled by upgraders. 2314 2. if model version > expected version, we need downgraders but not implemented yet. 2315 3. if model version == expected version, we don't need extra handling. 2316 2317 For op namespace only in model_opset_version, we should give a warning because it is missing from 2318 expected_opset_version. 2319 """ 2320 if not model_opset_version: 2321 raise RuntimeError("Serialized model should have opset version.") 2322 common_namespaces = { 2323 key for key in model_opset_version if key in self.expected_opset_version 2324 } 2325 for namespace in common_namespaces: 2326 model_version = model_opset_version[namespace] 2327 assert isinstance( 2328 model_version, int 2329 ), f"model_opset_version value should be int, got {model_version}" 2330 2331 compiler_version = self.expected_opset_version[namespace] 2332 assert isinstance( 2333 compiler_version, int 2334 ), f"expected_opset_version value should be int, got {compiler_version}" 2335 2336 # TODO(larryliu0820): Add support for upgrader & downgrader 2337 if model_version != compiler_version: 2338 raise NotImplementedError( 2339 f"Model opset version {model_opset_version} doesn't match to compiler opset version " 2340 f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet." 2341 ) 2342 for namespace in model_opset_version: 2343 if namespace in common_namespaces: 2344 continue 2345 log.warning( 2346 "Compiler doesn't have a version table for op namespace: {ns}. ", 2347 extra={"ns": namespace}, 2348 ) 2349 2350 2351class EnumEncoder(json.JSONEncoder): 2352 def default(self, obj): 2353 if isinstance(obj, Enum): 2354 return obj.value 2355 if isinstance(obj, bytes): 2356 return base64.b64encode(obj).decode("utf-8") 2357 return super().default(obj) 2358 2359 2360def _dataclass_to_dict(obj): 2361 if isinstance(obj, _Union): 2362 return {obj.type: _dataclass_to_dict(obj.value)} 2363 elif dataclasses.is_dataclass(obj): 2364 return { 2365 f.name: _dataclass_to_dict(getattr(obj, f.name)) 2366 for f in dataclasses.fields(obj) 2367 if not (f.default is None and getattr(obj, f.name) is None) 2368 } 2369 elif isinstance(obj, list): 2370 return [_dataclass_to_dict(x) for x in obj] 2371 elif isinstance(obj, tuple): 2372 return tuple(_dataclass_to_dict(x) for x in obj) 2373 elif isinstance(obj, dict): 2374 return {k: _dataclass_to_dict(v) for k, v in obj.items()} 2375 else: 2376 return obj 2377 2378 2379def serialize( 2380 exported_program: ep.ExportedProgram, 2381 opset_version: Optional[Dict[str, int]] = None, 2382) -> SerializedArtifact: 2383 serialized_program = ExportedProgramSerializer(opset_version).serialize( 2384 exported_program 2385 ) 2386 assert isinstance(serialized_program.exported_program, ExportedProgram) 2387 2388 json_program = json.dumps( 2389 _dataclass_to_dict(serialized_program.exported_program), cls=EnumEncoder 2390 ) 2391 json_bytes = json_program.encode("utf-8") 2392 artifact = SerializedArtifact( 2393 json_bytes, 2394 serialized_program.state_dict, 2395 serialized_program.constants, 2396 serialized_program.example_inputs, 2397 ) 2398 return artifact 2399 2400 2401def _dict_to_dataclass(cls, data): 2402 assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." 2403 if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): 2404 if data is None: 2405 return None 2406 ty_args = typing.get_args(cls) 2407 assert len(ty_args) == 2 2408 return _dict_to_dataclass(ty_args[0], data) 2409 elif isinstance(cls, type) and issubclass(cls, _Union): 2410 assert isinstance(data, dict) 2411 assert len(data) == 1 2412 _type = next(iter(data.keys())) 2413 _value = next(iter(data.values())) 2414 assert isinstance(_type, str) 2415 field_type = cls.__annotations__[_type] 2416 return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) 2417 elif dataclasses.is_dataclass(cls): 2418 obj = cls(**data) # type: ignore[assignment] 2419 type_hints = typing.get_type_hints(cls) 2420 for f in dataclasses.fields(cls): 2421 name = f.name 2422 new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) 2423 setattr(obj, name, new_field_obj) 2424 return obj 2425 elif isinstance(data, list): 2426 if len(data) == 0: 2427 return data 2428 d_type = typing.get_args(cls)[0] 2429 return [_dict_to_dataclass(d_type, d) for d in data] 2430 elif isinstance(data, dict): 2431 v_type = typing.get_args(cls)[1] 2432 return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()} 2433 return data 2434 2435 2436def deserialize( 2437 artifact: SerializedArtifact, 2438 expected_opset_version: Optional[Dict[str, int]] = None, 2439) -> ep.ExportedProgram: 2440 assert isinstance(artifact.exported_program, bytes) 2441 exported_program_str = artifact.exported_program.decode("utf-8") 2442 exported_program_dict = json.loads(exported_program_str) 2443 serialized_exported_program = _dict_to_dataclass( 2444 ExportedProgram, exported_program_dict 2445 ) 2446 return ExportedProgramDeserializer(expected_opset_version).deserialize( 2447 serialized_exported_program, 2448 artifact.state_dict, 2449 artifact.constants, 2450 artifact.example_inputs, 2451 ) 2452 2453 2454def _canonicalize_graph( 2455 sorted_inputs, sorted_outputs, graph 2456) -> Tuple[Graph, Dict[str, str]]: 2457 def _get_argument(a: Argument): 2458 if a.type == "as_none": 2459 return None 2460 elif a.type == "as_tensor": 2461 return a.as_tensor 2462 elif a.type == "as_tensors": 2463 return a.as_tensors 2464 elif a.type == "as_int": 2465 return None 2466 elif a.type == "as_ints": 2467 return None 2468 elif a.type == "as_float": 2469 return None 2470 elif a.type == "as_floats": 2471 return None 2472 elif a.type == "as_string": 2473 return None 2474 elif a.type == "as_strings": 2475 return None 2476 elif a.type == "as_sym_int": 2477 return a.as_sym_int 2478 elif a.type == "as_sym_ints": 2479 return a.as_sym_ints 2480 elif a.type == "as_scalar_type": 2481 return None 2482 elif a.type == "as_memory_format": 2483 return None 2484 elif a.type == "as_layout": 2485 return None 2486 elif a.type == "as_device": 2487 return None 2488 elif a.type == "as_bool": 2489 return None 2490 elif a.type == "as_bools": 2491 return None 2492 elif a.type == "as_sym_bool": 2493 return a.as_sym_bool 2494 elif a.type == "as_sym_bools": 2495 return a.as_sym_bools 2496 elif a.type == "as_graph": 2497 return None 2498 elif a.type == "as_optional_tensors": 2499 return a.as_optional_tensors 2500 elif a.type == "as_custom_obj": 2501 return None 2502 elif a.type == "as_operator": 2503 return None 2504 else: 2505 raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") 2506 2507 # Stage 1: Reorder named items. 2508 def for_args(f, a): 2509 assert isinstance(a, Argument) 2510 pytree.tree_map(f, _get_argument(a)) 2511 2512 def sort_nodes(nodes): 2513 @dataclass 2514 class Edges: 2515 outs: List[int] 2516 ins: int 2517 2518 graph_inputs: Set[str] = set() 2519 def_table: Dict[str, int] = {} 2520 edges: Dict[int, Edges] = {} 2521 candidates: List[Tuple[str, List[Tuple[str, List[int]]], int]] = [] 2522 rank: Dict[str, int] = {} 2523 ret: List[Node] = [] 2524 2525 def get_name(a) -> Optional[str]: 2526 if a is None: 2527 return None 2528 if isinstance(a, TensorArgument): 2529 return a.name 2530 elif isinstance(a, (SymIntArgument, SymBoolArgument)): 2531 if a.type == "as_name": 2532 return a.as_name 2533 elif a.type in ("as_int", "as_bool"): 2534 return None 2535 else: 2536 raise AssertionError(f"Unknown argument type: {a}") 2537 elif isinstance(a, OptionalTensorArgument): 2538 if a.type == "as_tensor": 2539 return a.as_tensor.name 2540 elif a.type == "as_none": 2541 return None 2542 else: 2543 raise AssertionError(f"Unknown optional tensor type: {a}") 2544 else: 2545 raise AssertionError(f"Unknown argument type: {a}") 2546 2547 for i in sorted_inputs: 2548 2549 def add_input(a): 2550 if s := get_name(a): 2551 graph_inputs.add(s) 2552 2553 for_args(add_input, i) 2554 2555 for idx, node in enumerate(nodes): 2556 2557 def add_def(a): 2558 if s := get_name(a): 2559 assert s not in def_table 2560 def_table[s] = idx 2561 2562 for o in node.outputs: 2563 for_args(add_def, o) 2564 2565 edges[idx] = Edges([], 0) 2566 2567 for idx, user in enumerate(nodes): 2568 2569 def add_edge(a): 2570 if s := get_name(a): 2571 if s not in def_table: 2572 assert s in graph_inputs 2573 return 2574 src = def_table[s] 2575 edges[src].outs.append(idx) 2576 edges[idx].ins += 1 2577 2578 for i in user.inputs: 2579 for_args(add_edge, i.arg) 2580 2581 def add_rank(a): 2582 if s := get_name(a): 2583 assert s not in rank 2584 rank[s] = len(rank) 2585 2586 def get_rank(a): 2587 if s := get_name(a): 2588 return rank[s] 2589 else: 2590 return -1 2591 2592 for i in sorted_inputs: 2593 for_args(add_rank, i) 2594 2595 def add_candidate(idx: int): 2596 def get_ranks(i): 2597 ranks = [] 2598 for_args(lambda x: ranks.append(get_rank(x)), i) 2599 return ranks 2600 2601 node = nodes[idx] 2602 args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] 2603 heapq.heappush(candidates, (node.target, args_rank, idx)) 2604 2605 for idx, e in edges.items(): 2606 if e.ins == 0: 2607 add_candidate(idx) 2608 2609 while len(candidates) > 0: 2610 _, _, idx = heapq.heappop(candidates) 2611 node = nodes[idx] 2612 for o in node.outputs: 2613 for_args(add_rank, o) 2614 ret.append(node) 2615 assert idx in edges 2616 for user in edges[idx].outs: 2617 e = edges[user] 2618 assert e.ins > 0 2619 e.ins -= 1 2620 if e.ins == 0: 2621 add_candidate(user) 2622 edges[idx].outs.clear() 2623 2624 return ret 2625 2626 sorted_nodes = sort_nodes(graph.nodes) 2627 assert len(sorted_nodes) == len(graph.nodes) 2628 2629 # Stage 2: Rename nodes. 2630 name_table: Dict[str, str] = {} 2631 2632 def rename_def(a): 2633 def _rename(arg_name, values): 2634 new_name = f"_{len(name_table)}" 2635 assert arg_name not in name_table 2636 name_table[arg_name] = new_name 2637 assert arg_name in values 2638 values[new_name] = values.pop(arg_name) 2639 return new_name 2640 2641 if a is None: 2642 return 2643 if isinstance(a, TensorArgument): 2644 a.name = _rename(a.name, graph.tensor_values) 2645 elif isinstance(a, SymIntArgument): 2646 if a.type == "as_name": 2647 a.as_name = _rename(a.as_name, graph.sym_int_values) 2648 elif isinstance(a, SymBoolArgument): 2649 if a.type == "as_name": 2650 a.as_name = _rename(a.as_name, graph.sym_bool_values) 2651 else: 2652 raise AssertionError(f"Unknown argument type: {a}") 2653 2654 def replace_use(a): 2655 if a is None: 2656 return 2657 if isinstance(a, TensorArgument): 2658 a.name = name_table.get(a.name, a.name) 2659 elif isinstance(a, SymIntArgument): 2660 if a.type == "as_name": 2661 a.as_name = name_table.get(a.as_name, a.as_name) 2662 elif isinstance(a, SymBoolArgument): 2663 if a.type == "as_name": 2664 a.as_name = name_table.get(a.as_name, a.as_name) 2665 elif isinstance(a, OptionalTensorArgument): 2666 if a.type == "as_tensor": 2667 a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name) 2668 else: 2669 raise AssertionError(f"Unknown argument type: {a}") 2670 2671 for i in sorted_inputs: 2672 for_args(rename_def, i) 2673 2674 for n in sorted_nodes: 2675 for o in n.outputs: 2676 for_args(rename_def, o) 2677 2678 for n in sorted_nodes: 2679 for i in n.inputs: 2680 for_args(replace_use, i.arg) 2681 2682 for o in sorted_outputs: 2683 for_args(replace_use, o) 2684 2685 # Stage 3: Remove unstable fields. 2686 for n in sorted_nodes: 2687 n.metadata.clear() 2688 2689 # Stage 4: Aggregate values. 2690 sorted_tensor_values = dict( 2691 sorted(graph.tensor_values.items(), key=operator.itemgetter(0)) 2692 ) 2693 sorted_sym_int_values = dict( 2694 sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) 2695 ) 2696 sorted_sym_bool_values = dict( 2697 sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) 2698 ) 2699 2700 # Stage 5: Recurse in subgraphs. 2701 counter = 0 2702 for node in sorted_nodes: 2703 for i in node.inputs: 2704 a = i.arg 2705 if a.type == "as_graph": 2706 a.as_graph.graph = _canonicalize_graph( 2707 a.as_graph.graph.inputs, a.as_graph.graph.outputs, a.as_graph.graph 2708 ) 2709 a.as_graph.name = f"_g{counter}" 2710 counter += 1 2711 2712 graph = Graph( 2713 inputs=sorted_inputs, 2714 outputs=sorted_outputs, 2715 nodes=sorted_nodes, 2716 tensor_values=sorted_tensor_values, 2717 sym_int_values=sorted_sym_int_values, 2718 sym_bool_values=sorted_sym_bool_values, 2719 is_single_tensor_return=graph.is_single_tensor_return, 2720 ) 2721 return graph, name_table 2722 2723 2724def canonicalize(ep: ExportedProgram) -> ExportedProgram: 2725 """ 2726 Normalize a serialized ExportedProgram, so that different eager program which 2727 shares the same semantics can get a single representation on disk. 2728 2729 This function canonicalizes an ExportedProgram by: 2730 2731 1. Sorting nodes in topological order. 2732 2. Rename nodes to have unique names. 2733 3. Remove unstable fields. 2734 4. Aggregate the above program fields. 2735 5. Recurse in subgraphs. 2736 2737 Args: 2738 ep (ExportedProgram): The ExportedProgram to canonicalize. 2739 2740 Returns: 2741 ExportedProgram: The canonicalized exported program. 2742 """ 2743 ep = copy.deepcopy(ep) 2744 2745 opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) 2746 range_constraints = dict( 2747 sorted(ep.range_constraints.items(), key=operator.itemgetter(0)) 2748 ) 2749 module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) 2750 signature = ep.graph_module.signature 2751 graph = ep.graph_module.graph 2752 2753 assert len(graph.inputs) == len(signature.input_specs) 2754 assert len(graph.outputs) == len(signature.output_specs) 2755 2756 def rank_input(inp) -> Tuple[int, Optional[str], int]: 2757 idx, (arg, spec) = inp 2758 assert isinstance(spec, InputSpec) 2759 if spec.type == "user_input": 2760 return 5, None, idx 2761 elif spec.type == "parameter": 2762 return 1, spec.parameter.parameter_name, idx 2763 elif spec.type == "buffer": 2764 return 2, spec.buffer.buffer_name, idx 2765 elif spec.type == "tensor_constant": 2766 return 3, spec.tensor_constant.tensor_constant_name, idx 2767 elif spec.type == "custom_obj": 2768 return 4, spec.custom_obj.custom_obj_name, idx 2769 elif spec.type == "token": 2770 return 0, None, idx 2771 elif spec.type == "constant_input": 2772 return 6, spec.constant_input.name, idx 2773 else: 2774 raise AssertionError(f"Unknown input type: {spec}") 2775 2776 def rank_output(out) -> Tuple[int, Optional[str], int]: 2777 idx, (arg, spec) = out 2778 assert isinstance(spec, OutputSpec) 2779 if spec.type == "user_output": 2780 return 3, None, idx 2781 elif spec.type == "loss_output": 2782 return 3, None, idx 2783 elif spec.type == "buffer_mutation": 2784 return 1, spec.buffer_mutation.buffer_name, idx 2785 elif spec.type == "gradient_to_parameter": 2786 return 4, spec.gradient_to_parameter.parameter_name, idx 2787 elif spec.type == "gradient_to_user_input": 2788 return 5, None, idx 2789 elif spec.type == "user_input_mutation": 2790 return 2, None, idx 2791 elif spec.type == "token": 2792 return 0, None, idx 2793 else: 2794 raise AssertionError(f"Unknown output type: {spec}") 2795 2796 sorted_ins = sorted( 2797 enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input 2798 ) 2799 sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] 2800 2801 sorted_outs = sorted( 2802 enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output 2803 ) 2804 sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] 2805 2806 sorted_graph, replace_table = _canonicalize_graph( 2807 sorted_inputs, sorted_outputs, graph 2808 ) 2809 2810 def replace_input(inp): 2811 assert isinstance(spec, InputSpec) 2812 if spec.type == "user_input": 2813 arg = spec.user_input.arg 2814 if arg.type == "as_tensor": 2815 t = arg.as_tensor 2816 t.name = replace_table[t.name] 2817 elif arg.type == "as_sym_int": 2818 s = arg.as_sym_int 2819 if s.type == "as_name": 2820 s.as_name = replace_table[s.as_name] 2821 elif s.type == "as_int": 2822 pass 2823 else: 2824 raise AssertionError(f"Unknown sym_int type: {s}") 2825 elif arg.type in ( 2826 "as_none", 2827 "as_bool", 2828 "as_int", 2829 "as_float", 2830 "as_string", 2831 "as_custom_obj", 2832 ): 2833 return 2834 else: 2835 raise AssertionError(f"Unknown input type: {arg}") 2836 elif spec.type == "parameter": 2837 t = spec.parameter.arg 2838 t.name = replace_table[t.name] 2839 elif spec.type == "buffer": 2840 t = spec.buffer.arg 2841 t.name = replace_table[t.name] 2842 elif spec.type == "tensor_constant": 2843 t = spec.tensor_constant.arg 2844 t.name = replace_table[t.name] 2845 elif spec.type == "custom_obj": 2846 return 2847 elif spec.type == "token": 2848 tok = spec.token.arg 2849 tok.name = replace_table[tok.name] 2850 elif spec.type == "constant_input": 2851 return 2852 else: 2853 raise AssertionError(f"Unknown input type: {spec}") 2854 2855 def replace_output(out): 2856 assert isinstance(spec, OutputSpec) 2857 if spec.type == "user_output": 2858 arg = spec.user_output.arg 2859 if arg.type == "as_tensor": 2860 t = arg.as_tensor 2861 t.name = replace_table[t.name] 2862 elif arg.type == "as_sym_int": 2863 s = arg.as_sym_int 2864 if s.type == "as_name": 2865 s.as_name = replace_table[s.as_name] 2866 elif s.type == "as_int": 2867 pass 2868 else: 2869 raise AssertionError(f"Unknown sym_int type: {s}") 2870 elif arg.type in ("as_none", "as_int", "as_float", "as_string"): 2871 return 2872 else: 2873 raise AssertionError(f"Unknown input type: {arg}") 2874 elif spec.type == "loss_output": 2875 t = spec.loss_output.arg 2876 t.name = replace_table[t.name] 2877 elif spec.type == "buffer_mutation": 2878 t = spec.buffer_mutation.arg 2879 t.name = replace_table[t.name] 2880 elif spec.type == "gradient_to_parameter": 2881 t = spec.gradient_to_parameter.arg 2882 t.name = replace_table[t.name] 2883 elif spec.type == "gradient_to_user_input": 2884 g = spec.gradient_to_user_input 2885 g.arg.name = replace_table[g.arg.name] 2886 g.user_input_name = replace_table[g.user_input_name] 2887 elif spec.type == "user_input_mutation": 2888 u = spec.user_input_mutation 2889 u.arg.name = replace_table[u.arg.name] 2890 u.user_input_name = replace_table[u.user_input_name] 2891 elif spec.type == "token": 2892 tok = spec.token.arg 2893 tok.name = replace_table[tok.name] 2894 else: 2895 raise AssertionError(f"Unknown output type: {spec}") 2896 2897 for spec in input_specs: 2898 replace_input(spec) 2899 2900 for spec in output_specs: 2901 replace_output(spec) 2902 2903 return ExportedProgram( 2904 graph_module=GraphModule( 2905 graph=sorted_graph, 2906 signature=GraphSignature( 2907 input_specs=list(input_specs), 2908 output_specs=list(output_specs), 2909 ), 2910 module_call_graph=module_call_graph, 2911 ), 2912 opset_version=opset_version, 2913 range_constraints=range_constraints, 2914 schema_version=ep.schema_version, 2915 dialect=ep.dialect, 2916 ) 2917