xref: /aosp_15_r20/external/pytorch/torchgen/gen_schema_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Any, Optional, Tuple, Union
2
3from torchgen.model import (
4    Annotation,
5    Argument,
6    Arguments,
7    BaseOperatorName,
8    BaseTy,
9    BaseType,
10    CustomClassType,
11    FunctionSchema,
12    ListType,
13    OperatorName,
14    Return,
15)
16
17
# Note: These utilities aren't actually used by torchgen itself; they generate a schema
# from real (example) arguments. For example, they are used to generate HigherOrderOperators'
# schemas, since the schema can vary between different instances of the same HOP.
21
22
class TypeGen:
    """Infer a torchgen schema type from a concrete example value."""

    # Plain Python scalar types and their schema base types. Looked up by the
    # exact ``type(obj)``, so ``True`` maps to ``bool`` (not ``int``) and
    # subclasses of these types are rejected as unsupported.
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        """Return the schema type describing ``obj``.

        Supported values:
          * ``torch.fx.GraphModule``, ``torch.Tensor``, ``torch.SymInt``,
            ``torch.SymBool`` -> the corresponding ``BaseType``;
          * ``torch.ScriptObject`` -> a ``CustomClassType`` named after the
            object's TorchScript type;
          * non-empty ``list``/``tuple`` whose elements all map to the same
            type -> a ``ListType`` of that element type, sized ``len(obj)``;
          * exact ``int``/``float``/``str``/``bool`` -> the matching ``BaseType``.

        Raises:
            RuntimeError: if ``obj`` is an empty sequence, a sequence whose
                elements map to different types, or of an unsupported type.
        """
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        elif isinstance(obj, (list, tuple)):
            # Explicit raise instead of ``assert``: asserts are stripped under
            # ``python -O``, after which an empty sequence would fail below with
            # an opaque IndexError on ``all_base_tys[0]``.
            if not obj:
                raise RuntimeError(
                    "Cannot generate schema for an empty sequence: "
                    "its element type cannot be inferred from an example."
                )
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            # Every element must map to one and the same schema type.
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])
59
60
class ReturnGen:
    """Build a schema ``Return`` from an example output value."""

    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        """Create a ``Return`` whose type is inferred from ``obj``.

        Args:
            name: Optional name for the return value.
            obj: Example output value; its type comes from ``TypeGen.from_example``.
            annotation: Optional annotation, passed through unchanged.

        Raises:
            Propagates any error raised by ``TypeGen.from_example``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Return(name, inferred_type, annotation)
67
68
class ArgumentGen:
    """Build a schema ``Argument`` from an example input value."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        """Create an ``Argument`` named ``name`` whose type is inferred from ``obj``.

        Args:
            name: Argument name.
            obj: Example input value; its type comes from ``TypeGen.from_example``.
            default: Optional default-value string, passed through unchanged.
            annotation: Optional annotation, passed through unchanged.

        Raises:
            Propagates any error raised by ``TypeGen.from_example``.
        """
        arg_type = TypeGen.from_example(obj)
        return Argument(name, arg_type, default=default, annotation=annotation)
77
78
class FunctionSchemaGen:
    """Build a ``FunctionSchema`` from example inputs and outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: Tuple[Tuple[str, Any], ...],
        example_outputs: Tuple[Any, ...],
    ) -> FunctionSchema:
        """Generate a schema for the operator ``op_name`` from example values.

        Args:
            op_name: Base name of the operator.
            example_inputs: ``(name, value)`` pairs, one per argument, in order.
                Each value's type is inferred with ``TypeGen.from_example``.
            example_outputs: Example return values, one per return, in order.

        Returns:
            A ``FunctionSchema`` whose arguments come from ``example_inputs``
            (no defaults, no annotations) and whose returns are unnamed and
            unannotated.

        Raises:
            Propagates any error raised by ``TypeGen.from_example`` for an
            input or output value.
        """
        args = tuple(
            ArgumentGen.from_example(name, inp, None, None)
            for name, inp in example_inputs
        )
        # Ignore the annotations and other attributes for now; more can be added
        # when needed. All inputs go in the third positional field of
        # ``Arguments`` and every other field is empty.
        # NOTE(review): presumably that field is the post-self positional
        # arguments -- confirm against torchgen.model.Arguments.
        arguments = Arguments(
            tuple(), None, args, tuple(), None, tuple(), tuple()
        )
        returns = tuple(
            ReturnGen.from_example(None, out, None) for out in example_outputs
        )
        # Separate local so the ``op_name: str`` parameter is not rebound to an
        # ``OperatorName``. NOTE(review): the three ``False`` flags and the empty
        # string presumably mean "not inplace / dunder / functional overload"
        # and "no overload name" -- confirm against torchgen.model.
        schema_name = OperatorName(
            BaseOperatorName(op_name, False, False, False), ""
        )
        return FunctionSchema(schema_name, arguments, returns)
98