xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/exporter/_schemas.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import collections.abc
5import dataclasses
6import inspect
7import logging
8import types
9import typing
10from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union
11
12import onnx
13
14import onnxscript
15from onnxscript import ir
16
17
logger = logging.getLogger(__name__)


class _Empty:
    """Type of the sentinel that marks "no default value was specified".

    ``None`` cannot play this role because ``None`` is itself a valid default
    value for a parameter.
    """

    def __repr__(self) -> str:
        # Render the sentinel by the name of its module-level instance.
        return "_EMPTY_DEFAULT"


# Singleton sentinel; test for it with ``is _EMPTY_DEFAULT``.
_EMPTY_DEFAULT = _Empty()
28
# Map from python type to corresponding ONNX AttributeProto type.
# Looked up by _get_attr_type to recognize scalar attribute annotations.
# Note that bool is encoded as an INT attribute.
_PY_TYPE_TO_ATTR_TYPE = {
    float: ir.AttributeType.FLOAT,
    int: ir.AttributeType.INT,
    str: ir.AttributeType.STRING,
    bool: ir.AttributeType.INT,
    ir.Tensor: ir.AttributeType.TENSOR,
    ir.TensorProtocol: ir.AttributeType.TENSOR,
    ir.Graph: ir.AttributeType.GRAPH,
    ir.GraphProtocol: ir.AttributeType.GRAPH,
}

# Map from python type to corresponding ONNX AttributeProto type,
# for repeated (i.e., list of) values. Keyed by the *element* type of a
# Sequence/list/tuple annotation (e.g. int for Sequence[int] -> INTS).
_LIST_TYPE_TO_ATTR_TYPE = {
    float: ir.AttributeType.FLOATS,
    int: ir.AttributeType.INTS,
    str: ir.AttributeType.STRINGS,
    bool: ir.AttributeType.INTS,
    ir.Tensor: ir.AttributeType.TENSORS,
    ir.TensorProtocol: ir.AttributeType.TENSORS,
    ir.Graph: ir.AttributeType.GRAPHS,
    ir.GraphProtocol: ir.AttributeType.GRAPHS,
}

# Every value type considered here: a tensor of each ir.DataType, plus a
# sequence of and an optional of each such tensor type. This is a shared,
# mutable module-level set; treat it as read-only.
_ALL_VALUE_TYPES = (
    {ir.TensorType(dtype) for dtype in ir.DataType}
    | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}
    | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType}
)

# TypeAnnotationValue represents the (value of) valid type-annotations recognized
# by ONNX Script. Currently, it supports
# - float, int, str (primitive attribute types)
# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
# - Tensor types
# - Sequence[Tensor] types
# - Unions of the Tensor and Sequence[Tensor] types above
# - TypeVars with above bounds
# - Above types with annotation attached
# For static typing purposes it is spelled simply as Any.
TypeAnnotationValue = Any
70
71
@dataclasses.dataclass(frozen=True)
class TypeConstraintParam:
    """Type constraint for a parameter.

    Attributes:
        name: Name of the parameter. E.g. "TFloat"
        allowed_types: Allowed types for the parameter.
        description: Free-form description of the constraint (e.g. the one
            provided by an ONNX schema). Not part of the hash.
    """

    name: str
    allowed_types: set[ir.TypeProtocol]
    description: str = ""

    def __hash__(self) -> int:
        # Hash the allowed types as a frozenset. Two equal sets may iterate in
        # different orders, and tuple hashing is order-sensitive, so hashing
        # tuple(allowed_types) could give equal constraints different hashes.
        return hash((self.name, frozenset(self.allowed_types)))

    def __str__(self) -> str:
        allowed_types_str = " | ".join(str(t) for t in self.allowed_types)
        return f"{self.name}={allowed_types_str}"

    @classmethod
    def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
        """Create a constraint allowing a tensor of any ``ir.DataType``."""
        return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description)

    @classmethod
    def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
        """Create a constraint allowing every type in ``_ALL_VALUE_TYPES``."""
        # Copy so the constraint never aliases the shared module-level set.
        return cls(name, set(_ALL_VALUE_TYPES), description)
99
100
@dataclasses.dataclass(frozen=True)
class Parameter:
    """A formal parameter of an operator.

    Attributes:
        name: Name of the parameter.
        type_constraint: Type constraint that the parameter's values satisfy.
        required: Whether the parameter must be supplied.
        variadic: Whether the parameter is variadic.
        default: The default value, or the ``_EMPTY_DEFAULT`` sentinel when no
            default was specified.
    """

    name: str
    type_constraint: TypeConstraintParam
    required: bool
    variadic: bool
    default: Any = _EMPTY_DEFAULT
    # TODO: Add other properties too

    def __str__(self) -> str:
        rendered = f"{self.name}: {self.type_constraint.name}"
        if not self.has_default():
            return rendered
        return f"{rendered} = {self.default}"

    def has_default(self) -> bool:
        # Identity check against the sentinel: None is a legitimate default.
        return self.default is not _EMPTY_DEFAULT
120
121
@dataclasses.dataclass(frozen=True)
class AttributeParameter:
    """A parameter in the function signature that represents an ONNX attribute.

    Attributes:
        name: Name of the attribute.
        type: ONNX attribute type of the parameter.
        required: Whether the attribute must be supplied.
        default: Default value as an ``ir.Attr``, or ``None`` when the
            attribute has no default.
    """

    name: str
    type: ir.AttributeType
    required: bool
    default: ir.Attr | None = None

    def __str__(self) -> str:
        head = f"{self.name}: {self.type.name}"
        return f"{head} = {self.default}" if self.has_default() else head

    def has_default(self) -> bool:
        return self.default is not None
139
140
def _get_type_from_str(
    type_str: str,
) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
    """Convert a type_str from ONNX Opschema to ir.TypeProtocol.

    A type str has the form of "tensor(float)" or composite type like
    "seq(tensor(float))" or "optional(seq(tensor(float)))". A bare element type
    such as "int64" is treated as "tensor(int64)".

    Raises:
        ValueError: If a wrapper other than tensor/seq/optional is found.
        An unknown element type fails in the ``ir.DataType[...]`` name lookup.
    """

    # TODO: Upstream this to IR

    # "seq(tensor(float))" -> ["seq", "tensor", "float"]: the wrappers, outermost
    # first, followed by the element type name.
    *wrappers, dtype_name = type_str.rstrip(")").split("(")

    # Convert the dtype to ir.DataType
    dtype = ir.DataType[dtype_name.upper()]

    # Start from a tensor of the element type so that a bare dtype string
    # (with no "tensor(...)" wrapper) still yields a concrete tensor type
    # instead of an UNDEFINED placeholder.
    type_: ir.TypeProtocol = ir.TensorType(dtype)

    # Apply the wrappers from the innermost outwards
    for wrapper in reversed(wrappers):
        if wrapper == "tensor":
            type_ = ir.TensorType(dtype)
        elif wrapper == "seq":
            type_ = ir.SequenceType(type_)
        elif wrapper == "optional":
            type_ = ir.OptionalType(type_)
        else:
            raise ValueError(f"Unknown type part: '{wrapper}' in type '{type_str}'")
    return type_  # type: ignore[return-value]
174
175
def _convert_formal_parameter(
    param: onnx.defs.OpSchema.FormalParameter,
    type_constraints: Mapping[str, TypeConstraintParam],
) -> Parameter:
    """Convert a formal parameter from ONNX Opschema to Parameter.

    When ``param.type_str`` names an entry of ``type_constraints`` that shared
    constraint is used. Otherwise ``type_str`` is parsed as a concrete type and
    wrapped in a single-type constraint named after the parameter.
    """
    options = onnx.defs.OpSchema.FormalParameterOption
    if param.type_str not in type_constraints:
        # param.type_str is a concrete type (e.g. 'tensor(int64)' or a plain
        # type like 'int64') rather than the name of a type parameter.
        type_constraint = TypeConstraintParam(
            name=param.name,
            allowed_types={_get_type_from_str(param.type_str)},
        )
    else:
        type_constraint = type_constraints[param.type_str]
    return Parameter(
        name=param.name,
        type_constraint=type_constraint,
        required=param.option != options.Optional,
        variadic=param.option == options.Variadic,
    )
195
196
def _is_optional(type_: type) -> bool:
    """Returns whether a type_ is an Optional.

    Recognizes ``Optional[X]`` / ``Union[X, None]`` and the PEP 604 spelling
    ``X | None`` (Python >= 3.10). Unions without ``None`` are not optional.
    """
    origin_type = typing.get_origin(type_)
    if origin_type is Union:
        # Optional[X] is just Union[X, None]; typing.get_origin never reports
        # ``Optional`` itself, on any Python version.
        return type(None) in typing.get_args(type_)
    if hasattr(types, "UnionType") and origin_type is types.UnionType:
        # PEP 604 ``X | None`` syntax, Python >= 3.10
        return type(None) in typing.get_args(type_)
    return False
214
215
def _get_attr_type(type_: type) -> ir.AttributeType:
    """Obtain the type of the attribute from a Python class.

    Returns:
        The scalar attribute type for a type listed in ``_PY_TYPE_TO_ATTR_TYPE``;
        the repeated attribute type for a Sequence/list/tuple whose first type
        argument is listed in ``_LIST_TYPE_TO_ATTR_TYPE``; otherwise
        ``ir.AttributeType.UNDEFINED``. Unparameterized containers such as a
        bare ``typing.List`` also yield UNDEFINED.
    """
    try:
        if type_ in _PY_TYPE_TO_ATTR_TYPE:
            return _PY_TYPE_TO_ATTR_TYPE[type_]
        # get_origin() normalizes typing aliases to runtime classes
        # (typing.List[int] -> list, typing.Sequence[int] -> collections.abc.Sequence),
        # and returns None for non-generic types.
        origin_type = typing.get_origin(type_)
        if origin_type in (collections.abc.Sequence, list, tuple):
            type_args = typing.get_args(type_)
            # Unparameterized aliases (e.g. typing.List) have an origin but no
            # type arguments, so guard before indexing.
            if type_args and type_args[0] in _LIST_TYPE_TO_ATTR_TYPE:
                return _LIST_TYPE_TO_ATTR_TYPE[type_args[0]]
    except TypeError:
        logger.warning("TypeError when checking %s.", type_, exc_info=True)
    return ir.AttributeType.UNDEFINED
238
239
def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None:
    """Returns the name of the type constraint for a given type annotation.

    Args:
        type_: A Python type.

    Returns:
        The name of the type constraint if it is a TypeVar.
        - ``Optional[T]`` resolves to the name of its first non-None member.
        - Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
        ``None`` when no TypeVar name can be derived, including for an
        unparameterized Sequence such as a bare ``typing.List``.
    """
    if isinstance(type_, TypeVar):
        return type_.__name__
    if _is_optional(type_):
        # Use the first non-None member of the Optional.
        for subtype in typing.get_args(type_):
            if subtype is type(None):
                continue
            return _get_type_constraint_name(subtype) or None
    origin_type = typing.get_origin(type_)
    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
        subtypes = typing.get_args(type_)
        if not subtypes:
            # No element type to derive a name from.
            return None
        type_param_name = _get_type_constraint_name(subtypes[0])
        return f"Sequence_{type_param_name}" if type_param_name else None
    return None
265
266
def _get_allowed_types_from_type_annotation(
    type_: TypeAnnotationValue,
) -> set[ir.TypeProtocol]:
    """Obtain the allowed types from a type annotation.

    Supported annotations:
    - ``onnxscript.onnx_types.TensorType``: every tensor type.
    - A TypeVar: the union over its constraints, else over its bound; every
      value type when it has neither.
    - A concrete tensor type exposing ``dtype`` (e.g. INT64): that tensor type.
    - ``Optional[X]`` / ``X | None``: the allowed types of the non-None members.
    - ``Union[X, Y]`` / ``X | Y``: the union of the members' allowed types.
    - ``Sequence[X]`` (or another Sequence subclass): sequences of X's allowed
      types; an unparameterized Sequence is treated like ``Sequence[Any]``.
    - Anything else: every value type.

    A new set is returned on every call, so callers may mutate the result
    without affecting module-level state.
    """
    if type_ is onnxscript.onnx_types.TensorType:
        # Any tensor type
        return {ir.TensorType(dtype) for dtype in ir.DataType}

    allowed_types: set[ir.TypeProtocol]

    if isinstance(type_, TypeVar):
        allowed_types = set()
        if constraints := type_.__constraints__:
            for constraint in constraints:
                allowed_types.update(
                    _get_allowed_types_from_type_annotation(constraint)
                )
        else:
            bound = type_.__bound__
            if bound is None:
                # Unconstrained TypeVar: copy so the shared set is never handed out
                allowed_types = set(_ALL_VALUE_TYPES)
            else:
                allowed_types.update(_get_allowed_types_from_type_annotation(bound))
        return allowed_types
    if hasattr(type_, "dtype"):
        # A single tensor type like INT64, FLOAT, etc.
        return {ir.TensorType(ir.DataType(type_.dtype))}
    if _is_optional(type_):
        allowed_types = set()
        for subtype in typing.get_args(type_):
            if subtype is type(None):
                continue
            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
        # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful.
        return allowed_types

    origin_type = typing.get_origin(type_)
    # ``X | Y`` (PEP 604, Python >= 3.10) has origin types.UnionType rather than
    # typing.Union; treat both spellings alike, as _is_optional does.
    is_pep604_union = hasattr(types, "UnionType") and origin_type is types.UnionType
    if origin_type is Union or is_pep604_union:
        allowed_types = set()
        for subtype in typing.get_args(type_):
            assert (
                subtype is not type(None)
            ), "Union should not contain None type because it is handled by _is_optional."
            allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
        return allowed_types

    if isinstance(origin_type, type) and issubclass(origin_type, Sequence):
        subtypes = typing.get_args(type_)
        # An unparameterized Sequence (e.g. a bare typing.List) has no element
        # type; treat it like Sequence[Any] instead of failing on subtypes[0].
        element_type = subtypes[0] if subtypes else Any
        return {
            ir.SequenceType(t)
            for t in _get_allowed_types_from_type_annotation(element_type)
        }

    # Allow everything by default (copied so the shared set is never handed out)
    return set(_ALL_VALUE_TYPES)
324
325
@dataclasses.dataclass
class OpSignature:
    """Schema for an operator.

    Attributes:
        domain: Domain of the operator. E.g. "".
        name: Name of the operator. E.g. "Add".
        overload: Overload name of the operator.
        params: Input parameters. When the op is an ONNX function definition,
          the order is according to the function signature. This means we can
          interleave ONNX inputs and ONNX attributes in the list.
        outputs: Output parameters.
        params_map: Parameters keyed by name. Derived from ``params`` in
          ``__post_init__``; not accepted by ``__init__``.
    """

    domain: str
    name: str
    overload: str
    params: Sequence[Parameter | AttributeParameter]
    outputs: Sequence[Parameter]
    params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
        init=False, repr=False
    )

    def __post_init__(self):
        # Build the name index once; if two params share a name, the last wins.
        self.params_map = {param.name: param for param in self.params}

    def get(self, name: str) -> Parameter | AttributeParameter:
        """Return the parameter named ``name`` (KeyError if there is none)."""
        return self.params_map[name]

    def __contains__(self, name: str) -> bool:
        return name in self.params_map

    def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
        return iter(self.params)

    def __str__(self) -> str:
        domain = self.domain or "''"
        # TODO: Double check the separator for overload
        overload = f"::{self.overload}" if self.overload else ""
        params = ", ".join(str(param) for param in self.params)
        outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs)
        # Collect each distinct (by name) type constraint of inputs and outputs
        type_constraints = {}
        for param in self.params:
            if isinstance(param, Parameter):
                type_constraints[param.type_constraint.name] = param.type_constraint
        for param in self.outputs:
            type_constraints[param.type_constraint.name] = param.type_constraint
        type_constraints_str = ", ".join(
            str(type_constraint) for type_constraint in type_constraints.values()
        )
        return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}"

    @classmethod
    def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature:
        """Produce an OpSignature from an ONNX Opschema.

        The resulting ``params`` list holds the schema inputs (in schema order)
        followed by the schema attributes.
        """
        type_constraints = {
            constraint.type_param_str: TypeConstraintParam(
                name=constraint.type_param_str,
                allowed_types={
                    _get_type_from_str(type_str)
                    for type_str in constraint.allowed_type_strs
                },
                description=constraint.description,
            )
            for constraint in opschema.type_constraints
        }

        params: list[Parameter | AttributeParameter] = [
            _convert_formal_parameter(param, type_constraints)
            for param in opschema.inputs
        ]

        for param in opschema.attributes.values():
            default_attr = (
                ir.serde.deserialize_attribute(param.default_value)
                if param.default_value is not None
                else None
            )
            if (
                default_attr is not None
                and default_attr.type == ir.AttributeType.UNDEFINED
            ):
                # NOTE: ONNX may expose a missing default as an *empty*
                # AttributeProto rather than None. It deserializes to an
                # UNDEFINED-typed attribute, which is not a usable default.
                default_attr = None
            if default_attr is not None:
                # Set the name of the default attribute because it may have a different name from the parameter
                default_attr.name = param.name
            params.append(
                AttributeParameter(
                    name=param.name,
                    type=ir.AttributeType(param.type),  # type: ignore[arg-type]
                    required=param.required,
                    default=default_attr,  # type: ignore[arg-type]
                )
            )

        outputs = [
            _convert_formal_parameter(param, type_constraints)
            for param in opschema.outputs
        ]

        return cls(
            domain=opschema.domain,
            name=opschema.name,
            overload="",
            params=params,
            outputs=outputs,
        )

    @classmethod
    def from_function(
        cls, func, domain: str, name: str | None = None, overload: str = ""
    ) -> OpSignature:
        """Produce an OpSignature from a function using type annotation.

        Parameters whose annotation maps to an ONNX attribute type become
        ``AttributeParameter``s. Every other parameter, including unannotated
        ones, becomes an input ``Parameter``. Inputs and outputs whose
        annotations resolve to the same constraint name (e.g. the same TypeVar)
        share one ``TypeConstraintParam``.

        Args:
            func: The function whose signature and type hints are inspected.
            domain: Domain of the operator.
            name: Name of the operator. Defaults to ``func.__name__``.
            overload: Overload name of the operator.

        Returns:
            The OpSignature describing ``func``.
        """

        py_signature = inspect.signature(func)
        # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases
        # https://github.com/python/cpython/issues/102405
        type_hints = typing.get_type_hints(func)

        params: list[Parameter | AttributeParameter] = []
        # Type constraints created so far, keyed by the constraint's name
        type_constraints: dict[str, TypeConstraintParam] = {}

        for param in py_signature.parameters.values():
            has_default = param.default is not inspect.Parameter.empty

            if param.name not in type_hints:
                logger.warning(
                    "Missing annotation for parameter '%s' from %s. Treating as an Input.",
                    param.name,
                    py_signature,
                )
                type_constraint = TypeConstraintParam.any_value(f"T_{param.name}")
                # Register under the constraint's own name ("T_<param>"), the
                # key scheme used for every other entry in type_constraints.
                type_constraints.setdefault(type_constraint.name, type_constraint)
                params.append(
                    Parameter(
                        name=param.name,
                        type_constraint=type_constraint,
                        required=not has_default,
                        # TODO: Handle variadic
                        variadic=False,
                        default=param.default if has_default else _EMPTY_DEFAULT,
                    )
                )
                continue

            type_ = type_hints[param.name]
            if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
                # Construct the default attribute
                # TODO: Use ir_convenience instead to handle int as float
                default = (
                    ir.Attr(param.name, attr_type, param.default)
                    if has_default
                    else None
                )
                params.append(
                    AttributeParameter(
                        name=param.name,
                        type=attr_type,
                        required=not has_default,
                        default=default,
                    )
                )
                continue

            # Obtain the type constraint from the type annotation

            # 1. Get a type constraint name from the type annotation
            # If the type annotation is a TypeVar or Optional[TypeVar], get its name
            # Otherwise, name it T_{param.name}
            type_constraint_name = _get_type_constraint_name(type_)
            if type_constraint_name is None:
                type_constraint_name = f"T_{param.name}"

            # 2. If the type constraint param is already initialized, use it
            if type_constraint_name in type_constraints:
                type_constraint = type_constraints[type_constraint_name]
            else:
                # 3. Otherwise, create a new TypeConstraintParam
                type_constraint = TypeConstraintParam(
                    name=type_constraint_name,
                    allowed_types=_get_allowed_types_from_type_annotation(type_),
                )
                type_constraints[type_constraint_name] = type_constraint
            # 4. Create Parameter
            params.append(
                Parameter(
                    name=param.name,
                    type_constraint=type_constraint,
                    required=not has_default,
                    # TODO: Handle variadic
                    variadic=False,
                    default=param.default if has_default else _EMPTY_DEFAULT,
                )
            )

        outputs: list[Parameter] = []
        return_type = type_hints.get("return")
        # No return annotation means no outputs. typing.get_type_hints reports a
        # ``-> None`` annotation as NoneType, which also means no outputs.
        if return_type is not None and return_type is not type(None):
            if typing.get_origin(return_type) is tuple:
                # Multiple returns
                return_types = typing.get_args(return_type)
            else:
                return_types = (return_type,)

            for i, return_type_i in enumerate(return_types):
                return_param_name = _get_type_constraint_name(return_type_i)
                if return_param_name in type_constraints:
                    # Share the constraint with an input (e.g. the same TypeVar)
                    type_constraint = type_constraints[return_param_name]
                else:
                    return_param_name = f"TReturn{i}"
                    type_constraint = TypeConstraintParam(
                        name=return_param_name,
                        allowed_types=_get_allowed_types_from_type_annotation(
                            return_type_i
                        ),
                    )
                    type_constraints[return_param_name] = type_constraint
                outputs.append(
                    Parameter(
                        name=return_param_name,
                        type_constraint=type_constraint,
                        required=True,
                        variadic=False,
                        default=_EMPTY_DEFAULT,
                    )
                )

        return cls(
            domain=domain,
            name=name or func.__name__,
            overload=overload,
            params=params,
            outputs=outputs,
        )
562