1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import collections.abc 5import dataclasses 6import inspect 7import logging 8import types 9import typing 10from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union 11 12import onnx 13 14import onnxscript 15from onnxscript import ir 16 17 18logger = logging.getLogger(__name__) 19 20 21# A special value to indicate that the default value is not specified 22class _Empty: 23 def __repr__(self): 24 return "_EMPTY_DEFAULT" 25 26 27_EMPTY_DEFAULT = _Empty() 28 29# Map from python type to corresponding ONNX AttributeProto type 30_PY_TYPE_TO_ATTR_TYPE = { 31 float: ir.AttributeType.FLOAT, 32 int: ir.AttributeType.INT, 33 str: ir.AttributeType.STRING, 34 bool: ir.AttributeType.INT, 35 ir.Tensor: ir.AttributeType.TENSOR, 36 ir.TensorProtocol: ir.AttributeType.TENSOR, 37 ir.Graph: ir.AttributeType.GRAPH, 38 ir.GraphProtocol: ir.AttributeType.GRAPH, 39} 40 41# Map from python type to corresponding ONNX AttributeProto type, 42# for repeated (i.e., list of) values 43_LIST_TYPE_TO_ATTR_TYPE = { 44 float: ir.AttributeType.FLOATS, 45 int: ir.AttributeType.INTS, 46 str: ir.AttributeType.STRINGS, 47 bool: ir.AttributeType.INTS, 48 ir.Tensor: ir.AttributeType.TENSORS, 49 ir.TensorProtocol: ir.AttributeType.TENSORS, 50 ir.Graph: ir.AttributeType.GRAPHS, 51 ir.GraphProtocol: ir.AttributeType.GRAPHS, 52} 53 54_ALL_VALUE_TYPES = ( 55 {ir.TensorType(dtype) for dtype in ir.DataType} 56 | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} 57 | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} 58) 59 60# TypeAnnotationValue represents the (value of) valid type-annotations recognized 61# by ONNX Script. 
# Currently, it supports
# - float, int, str (primitive attribute types)
# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
# - Tensor types
# - Sequence[Tensor] types
# - Union of above 2
# - TypeVars with above bounds
# - Above types with annotation attached
TypeAnnotationValue = Any


@dataclasses.dataclass(frozen=True)
class TypeConstraintParam:
    """Type constraint for a parameter.

    Attributes:
        name: Name of the parameter. E.g. "TFloat"
        allowed_types: Allowed types for the parameter.
        description: Optional human-readable description of the constraint.
    """

    name: str
    allowed_types: set[ir.TypeProtocol]
    description: str = ""

    def __hash__(self) -> int:
        # Hash an order-independent view of the set. ``tuple(set)`` depends on
        # the set's iteration order, which is not canonical, so two equal
        # constraints could otherwise produce different hashes.
        return hash((self.name, frozenset(self.allowed_types)))

    def __str__(self) -> str:
        allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
        return f"{self.name}={allowed_types_str}"

    @classmethod
    def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
        """Create a constraint that allows a tensor of any data type."""
        return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)

    @classmethod
    def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
        """Create a constraint that allows every type in ``_ALL_VALUE_TYPES``.

        NOTE: The module-level set is passed through without copying, so the
        returned constraint shares it; treat ``allowed_types`` as read-only.
        """
        return cls(name, _ALL_VALUE_TYPES, description)  # type: ignore[arg-type]


@dataclasses.dataclass(frozen=True)
class Parameter:
    """A formal parameter of an operator."""

    name: str
    type_constraint: TypeConstraintParam
    required: bool
    variadic: bool
    default: Any = _EMPTY_DEFAULT
    # TODO: Add other properties too

    def __str__(self) -> str:
        type_str = self.type_constraint.name
        if self.has_default():
            return f"{self.name}: {type_str} = {self.default}"
        return f"{self.name}: {type_str}"

    def has_default(self) -> bool:
        """Return True if a default other than the ``_EMPTY_DEFAULT`` sentinel is set."""
        return self.default is not _EMPTY_DEFAULT


@dataclasses.dataclass(frozen=True)
class AttributeParameter:
    """A parameter in the function signature that represents an ONNX attribute."""

    name: str
    type: ir.AttributeType
    required: bool
    default: ir.Attr | None = None

    def __str__(self) -> str:
        type_str = self.type.name
        if self.has_default():
            return f"{self.name}: {type_str} = {self.default}"
        return f"{self.name}: {type_str}"

    def has_default(self) -> bool:
        """Return True if a default attribute is set (i.e. ``default`` is not None)."""
        return self.default is not None


def _get_type_from_str(
    type_str: str,
) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
    """Convert a type_str from ONNX Opschema to ir.TypeProtocol.

    A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".

    Args:
        type_str: The type string to parse.

    Returns:
        The IR type built by applying the wrappers from the innermost outwards.

    Raises:
        ValueError: If a wrapper other than "tensor", "seq" or "optional" is found.
    """

    # TODO: Upstream this to IR

    # Split the type_str into a sequence of wrapper names and the dtype:
    # 1. Remove the ending ")"
    stripped = type_str.rstrip(")")
    # 2. Split the type_str by "("
    type_parts = stripped.split("(")

    # The innermost part is the dtype name; convert it to ir.DataType
    dtype = ir.DataType[type_parts[-1].upper()]

    # Create a placeholder type first.
    # NOTE(review): if type_str has no wrapper at all (e.g. a bare "int64"),
    # the loop below does not run and this UNDEFINED placeholder is returned.
    type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED)

    # Construct the type from the innermost wrapper to the outermost one
    for type_part in reversed(type_parts[:-1]):
        if type_part == "tensor":
            type_ = ir.TensorType(dtype)
        elif type_part == "seq":
            type_ = ir.SequenceType(type_)
        elif type_part == "optional":
            type_ = ir.OptionalType(type_)
        else:
            raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'")
    return type_  # type: ignore[return-value]


def _convert_formal_parameter(
    param: onnx.defs.OpSchema.FormalParameter,
    type_constraints: Mapping[str, TypeConstraintParam],
) -> Parameter:
    """Convert a formal parameter from ONNX Opschema to Parameter.

    Args:
        param: The OpSchema formal parameter to convert.
        type_constraints: Known type constraints keyed by type parameter name.

    Returns:
        A Parameter that reuses the matching constraint when ``param.type_str``
        names one, or a single-type constraint parsed from ``param.type_str``.
    """
    if param.type_str in type_constraints:
        type_constraint = type_constraints[param.type_str]
    else:
        # param.type_str can be a plain type like 'int64'.
        type_constraint = TypeConstraintParam(
            name=param.name,
            allowed_types={_get_type_from_str(param.type_str)},
        )
    return Parameter(
        name=param.name,
        type_constraint=type_constraint,
        required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional,
        variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic,
    )


def _is_union_origin(origin_type: Any) -> bool:
    """Returns whether an origin denotes a union (``Union[...]`` or ``A | B``)."""
    if origin_type is Union:
        return True
    # types.UnionType (the type of ``A | B``) only exists on Python >= 3.10
    union_type = getattr(types, "UnionType", None)
    return union_type is not None and origin_type is union_type


def _is_optional(type_: type) -> bool:
    """Returns whether a type_ is an Optional, i.e. a union that includes None."""
    origin_type = typing.get_origin(type_)
    if origin_type is Optional:
        # Defensive only: Optional[X] is expected to report Union as its
        # origin and is therefore handled by the union check below.
        return True
    return _is_union_origin(origin_type) and type(None) in typing.get_args(type_)


def _get_attr_type(type_: type) -> ir.AttributeType:
    """Obtain the type of the attribute from a Python class.

    Returns:
        The matching attribute type, or ``ir.AttributeType.UNDEFINED`` when the
        annotation does not describe an attribute.
    """
    try:
        if type_ in _PY_TYPE_TO_ATTR_TYPE:
            return _PY_TYPE_TO_ATTR_TYPE[type_]
        origin_type = typing.get_origin(type_)
        if origin_type is None:
            return ir.AttributeType.UNDEFINED
        if origin_type in (
            collections.abc.Sequence,
            Sequence,
            typing.List,
            list,
            typing.Tuple,
            tuple,
        ):
            # Only the first type argument decides the element type
            inner_type = typing.get_args(type_)[0]
            if inner_type in _LIST_TYPE_TO_ATTR_TYPE:
                return _LIST_TYPE_TO_ATTR_TYPE[inner_type]
    except TypeError:
        # The membership tests above can raise TypeError (e.g. for unhashable
        # annotations); log it and treat the annotation as "not an attribute".
        logger.warning("TypeError when checking %s.", type_, exc_info=True)
    return ir.AttributeType.UNDEFINED


def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None:
    """Returns the name of the type constraint for a given type annotation.

    Args:
        type_: A Python type.

    Returns:
        The name of the type constraint if it is a TypeVar.
        - Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
        - For an Optional, the name derived from its first non-None member.
        - None if no name can be derived.
    """
    if isinstance(type_, TypeVar):
        return type_.__name__
    if _is_optional(type_):
        for subtype in typing.get_args(type_):
            if subtype is type(None):
                continue
            # Only the first non-None member is considered
            type_param_name = _get_type_constraint_name(subtype)
            return type_param_name if type_param_name else None
    origin_type = typing.get_origin(type_)
    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
        subtypes = typing.get_args(type_)
        type_param_name = _get_type_constraint_name(subtypes[0])
        return f"Sequence_{type_param_name}" if type_param_name else None
    return None


def _get_allowed_types_from_type_annotation(
    type_: TypeAnnotationValue,
) -> set[ir.TypeProtocol]:
    """Obtain the allowed types from a type annotation.

    Args:
        type_: The annotation to inspect.

    Returns:
        The set of IR types the annotation allows. Annotations that are not
        recognized allow every type in ``_ALL_VALUE_TYPES``.
    """
    if type_ is onnxscript.onnx_types.TensorType:
        # Any tensor type
        return {ir.TensorType(dtype) for dtype in ir.DataType}

    allowed_types: set[ir.TypeProtocol]

    if isinstance(type_, TypeVar):
        allowed_types = set()
        if constraints := type_.__constraints__:
            for constraint in constraints:
                allowed_types.update(
                    _get_allowed_types_from_type_annotation(constraint)
                )
        else:
            bound = type_.__bound__
            if bound is None:
                # Unconstrained and unbound TypeVar: anything goes
                allowed_types = _ALL_VALUE_TYPES  # type: ignore[assignment]
            else:
                allowed_types.update(_get_allowed_types_from_type_annotation(bound))
        return allowed_types
    if hasattr(type_, "dtype"):
        # A single tensor type like INT64, FLOAT, etc.
        return {ir.TensorType(ir.DataType(type_.dtype))}

    origin_type = typing.get_origin(type_)
    if _is_optional(type_) or _is_union_origin(origin_type):
        # Covers Optional[X], Union[X, Y] and the PEP 604 spelling ``X | Y``.
        # The allowed types are the union of what each non-None member allows.
        allowed_types = set()
        for subtype in typing.get_args(type_):
            if subtype is type(None):
                continue
            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
        # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful.
        return allowed_types

    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
        subtypes = typing.get_args(type_)
        return {
            ir.SequenceType(t)
            for t in _get_allowed_types_from_type_annotation(subtypes[0])
        }

    # Allow everything by default
    return _ALL_VALUE_TYPES  # type: ignore[return-value]


@dataclasses.dataclass
class OpSignature:
    """Schema for an operator.

    Attributes:
        domain: Domain of the operator. E.g. "".
        name: Name of the operator. E.g. "Add".
        overload: Overload name of the operator.
        params: Input parameters. When the op is an ONNX function definition,
            the order is according to the function signature. This means we can
            interleave ONNX inputs and ONNX attributes in the list.
        outputs: Output parameters.
        params_map: Lookup of ``params`` by name; built in ``__post_init__``.
    """

    domain: str
    name: str
    overload: str
    params: Sequence[Parameter | AttributeParameter]
    outputs: Sequence[Parameter]
    params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
        init=False, repr=False
    )

    def __post_init__(self):
        self.params_map = {param.name: param for param in self.params}

    def get(self, name: str) -> Parameter | AttributeParameter:
        """Return the parameter called ``name``; raises KeyError if absent."""
        return self.params_map[name]

    def __contains__(self, name: str) -> bool:
        return name in self.params_map

    def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
        return iter(self.params)

    def __str__(self) -> str:
        domain = self.domain or "''"
        # TODO: Double check the separator for overload
        overload = f"::{self.overload}" if self.overload else ""
        params = ", ".join(str(param) for param in self.params)
        outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
        # Collect each distinct constraint once, keyed by its name.
        # AttributeParameter entries carry no type constraint and are skipped.
        type_constraints = {}
        for param in self.params:
            if isinstance(param, Parameter):
                type_constraints[param.type_constraint.name] = param.type_constraint
        for param in self.outputs:
            type_constraints[param.type_constraint.name] = param.type_constraint
        type_constraints_str = ", ".join(
            str(type_constraint) for type_constraint in type_constraints.values()
        )
        return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"

    @classmethod
    def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature:
        """Produce an OpSignature from an ONNX Opschema.

        The resulting ``params`` list holds the schema inputs first, followed by
        the schema attributes. ``overload`` is always the empty string.
        """
        type_constraints = {
            constraint.type_param_str: TypeConstraintParam(
                name=constraint.type_param_str,
                allowed_types={
                    _get_type_from_str(type_str)
                    for type_str in constraint.allowed_type_strs
                },
                description=constraint.description,
            )
            for constraint in opschema.type_constraints
        }

        params = [
            _convert_formal_parameter(param, type_constraints)
            for param in opschema.inputs
        ]

        for param in opschema.attributes.values():
            default_attr = (
                ir.serde.deserialize_attribute(param.default_value)
                if param.default_value is not None
                else None
            )
            if default_attr is not None:
                # Set the name of the default attribute because it may have a different name from the parameter
                default_attr.name = param.name
            params.append(
                AttributeParameter(
                    name=param.name,
                    type=ir.AttributeType(param.type),  # type: ignore[arg-type]
                    required=param.required,
                    default=default_attr,  # type: ignore[arg-type]
                )
            )

        outputs = [
            _convert_formal_parameter(param, type_constraints)
            for param in opschema.outputs
        ]

        return cls(
            domain=opschema.domain,
            name=opschema.name,
            overload="",
            params=params,
            outputs=outputs,
        )

    @classmethod
    def from_function(
        cls, func, domain: str, name: str | None = None, overload: str = ""
    ) -> OpSignature:
        """Produce an OpSignature from a function using type annotation.

        Args:
            func: The Python function whose signature and type hints are read.
            domain: Domain to record on the signature.
            name: Operator name; defaults to ``func.__name__`` when falsy.
            overload: Overload name to record on the signature.

        Returns:
            An OpSignature in which each Python parameter becomes either an
            AttributeParameter (when its annotation maps to an attribute type)
            or a Parameter (otherwise, including unannotated parameters).
        """

        py_signature = inspect.signature(func)
        # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
        # https://github.com/python/cpython/issues/102405
        type_hints = typing.get_type_hints(func)

        params: list[Parameter | AttributeParameter] = []
        # Type constraints created so far, keyed by type constraint name
        type_constraints: dict[str, TypeConstraintParam] = {}

        for param in py_signature.parameters.values():
            has_default = param.default is not inspect.Parameter.empty

            if param.name not in type_hints:
                logger.warning(
                    "Missing annotation for parameter '%s' from %s. Treating as an Input.",
                    param.name,
                    py_signature,
                )
                type_constraint = TypeConstraintParam.any_value(f"T_{param.name}")
                # Key by the constraint name (not the parameter name) so the
                # registry is keyed the same way as for annotated parameters
                # and a TypeVar named like this parameter cannot pick it up.
                type_constraints[type_constraint.name] = type_constraint
                params.append(
                    Parameter(
                        name=param.name,
                        type_constraint=type_constraint,
                        required=not has_default,
                        # TODO: Handle variadic
                        variadic=False,
                        default=param.default if has_default else _EMPTY_DEFAULT,
                    )
                )
                continue

            type_ = type_hints[param.name]
            attr_type = _get_attr_type(type_)
            if attr_type != ir.AttributeType.UNDEFINED:
                # The annotation describes an ONNX attribute.
                # Construct the default attribute
                if has_default:
                    # TODO: Use ir_convenience instead to handle int as float
                    default = ir.Attr(param.name, attr_type, param.default)
                else:
                    default = None
                params.append(
                    AttributeParameter(
                        name=param.name,
                        type=attr_type,
                        required=not has_default,
                        default=default,
                    )
                )
                continue

            # Otherwise the parameter is an input; obtain the type constraint
            # from the type annotation.

            # 1. Get a type constraint name from the type annotation
            # If the type annotation is a TypeVar or Optional[TypeVar], get its name
            # Otherwise, name it T_{param.name}
            type_constraint_name = _get_type_constraint_name(type_)
            if type_constraint_name is None:
                type_constraint_name = f"T_{param.name}"

            # 2. If the type constraint param is already initialized, use it
            if type_constraint_name in type_constraints:
                type_constraint = type_constraints[type_constraint_name]
            else:
                # 3. Otherwise, create a new TypeConstraintParam
                type_constraint = TypeConstraintParam(
                    name=type_constraint_name,
                    allowed_types=_get_allowed_types_from_type_annotation(type_),
                )
                type_constraints[type_constraint_name] = type_constraint
            # 4. Create Parameter
            params.append(
                Parameter(
                    name=param.name,
                    type_constraint=type_constraint,
                    required=not has_default,
                    # TODO: Handle variadic
                    variadic=False,
                    default=param.default if has_default else _EMPTY_DEFAULT,
                )
            )

        return_type = type_hints.get("return")

        outputs = []
        # A missing return annotation yields no outputs
        if return_type is not None:
            if typing.get_origin(return_type) is tuple:
                # Multiple returns
                return_types = typing.get_args(return_type)
            else:
                return_types = [return_type]  # type: ignore[assignment]

            for i, return_type_i in enumerate(return_types):
                if (
                    return_param_name := _get_type_constraint_name(return_type_i)
                ) in type_constraints:
                    # Reuse the constraint already created for an input
                    type_constraint = type_constraints[return_param_name]
                else:
                    return_param_name = f"TReturn{i}"
                    type_constraint = TypeConstraintParam(
                        name=return_param_name,
                        allowed_types=_get_allowed_types_from_type_annotation(
                            return_type_i
                        ),
                    )
                    type_constraints[return_param_name] = type_constraint
                outputs.append(
                    Parameter(
                        name=return_param_name,
                        type_constraint=type_constraint,
                        required=True,
                        variadic=False,
                        default=_EMPTY_DEFAULT,
                    )
                )

        return cls(
            domain=domain,
            name=name or func.__name__,
            overload=overload,
            params=params,
            outputs=outputs,
        )