# mypy: allow-untyped-defs
import dataclasses
import inspect
import sys
from typing import Any, Callable, Dict, Iterable, Tuple, Union

import torch
from torch import _C, _utils_internal
from torch._ops import OpOverload


@dataclasses.dataclass
class Kernel:
    """Models a (function, source location)"""

    func: Callable
    source: str

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


class RegistrationHandle:
    """Calls the registered on_destroy callback when someone calls .destroy() on it"""

    def __init__(self, on_destroy: Callable):
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        self._on_destroy()


def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.
    """
    frame = inspect.getframeinfo(sys._getframe(stacklevel))
    source = f"{frame.filename}:{frame.lineno}"
    return source
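
# Example usage of get_source (illustrative sketch; the path and line number are
# hypothetical):
#
#     def some_wrapper():
#         # returns the source of whoever called some_wrapper,
#         # e.g. "/path/to/user_code.py:10"
#         return get_source(stacklevel=2)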


def parse_namespace(qualname: str) -> Tuple[str, str]:
    splits = qualname.split("::")
    if len(splits) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    return splits[0], splits[1]
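
# Example usage of parse_namespace (illustrative sketch):
#
#     parse_namespace("mylib::foo")     # -> ("mylib", "foo")
#     parse_namespace("aten::sin")      # -> ("aten", "sin")
#     parse_namespace("not_qualified")  # raises ValueError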


def lookup_op(qualname: str) -> OpOverload:
    namespace, name = parse_namespace(qualname)
    if "." in name:
        name, overload = name.split(".")
    else:
        overload = "default"
    ns = getattr(torch.ops, namespace)
    packet = getattr(ns, name)
    return getattr(packet, overload)
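
# Example usage of lookup_op (illustrative sketch):
#
#     lookup_op("aten::sin")         # -> torch.ops.aten.sin.default
#     lookup_op("aten::add.Tensor")  # -> torch.ops.aten.add.Tensor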


def is_builtin(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    return op.namespace in {"aten", "prim", "prims"}


def is_functional_schema(schema: Any) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return
    """

    def is_functional(schema):
        if schema.is_mutable:
            return False
        rets = schema.returns
        is_non_mutating_view = len(rets) > 0 and any(
            r.alias_info is not None and not r.alias_info.is_write for r in rets
        )
        if is_non_mutating_view:
            return False
        if not schema.returns:
            return False
        return True

    if isinstance(schema, torch._C.FunctionSchema):
        return is_functional(schema)

    # Lazy import because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema

    if isinstance(schema, str):
        schema = FunctionSchema.parse(schema)
    assert isinstance(schema, FunctionSchema)
    return is_functional(schema)
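
# Examples for is_functional_schema (illustrative sketch, using hypothetical schemas):
#
#     "foo(Tensor x, Tensor y) -> Tensor"   is functional, while
#     "foo_(Tensor(a!) x) -> Tensor(a!)"    is not (it mutates x),
#     "foo(Tensor(a) x) -> Tensor(a)"       is not (it returns a view of x),
#     "foo(Tensor x) -> ()"                 is not (it has no returns).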


# should be torch._C.JitType but that annotation is busted
def is_tensorlist_like_type(typ: Any) -> bool:
    return (
        typ == _C.ListType(_C.TensorType.get())
        or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
    )


# should be torch._C.JitType but that annotation is busted
def is_tensor_like_type(typ: Any) -> bool:
    return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
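
# Example (illustrative sketch): in JIT schema notation, is_tensor_like_type
# matches "Tensor" and "Tensor?", while is_tensorlist_like_type matches
# "Tensor[]", "Tensor?[]", "Tensor[]?", and "Tensor?[]?".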


def mutates_and_returns_first_arg(op: OpOverload):
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
    """
    if op.namespace != "aten":
        return False
    schema = op._schema
    if not len(schema.returns) == 1:
        return False
    if schema.returns[0].alias_info is None:
        return False
    alias_set = schema.returns[0].alias_info.after_set
    if len(alias_set) != 1:
        return False
    loc = next(iter(alias_set))
    if len(schema.arguments) < 1:
        return False
    first_arg = schema.arguments[0]
    if first_arg.alias_info is None:
        return False
    if not first_arg.alias_info.is_write:
        return False
    alias_set = first_arg.alias_info.after_set
    if len(alias_set) != 1:
        return False
    if loc != next(iter(alias_set)):
        return False
    for arg in schema.arguments[1:]:
        if arg.alias_info is not None:
            return False
    return True
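
# Example usage of mutates_and_returns_first_arg (illustrative sketch):
#
#     mutates_and_returns_first_arg(torch.ops.aten.add_.Tensor)  # True (in-place)
#     mutates_and_returns_first_arg(torch.ops.aten.add.Tensor)   # False (out-of-place)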


def fill_defaults(schema, args, kwargs):
    new_args = []
    new_kwargs = {}
    for i in range(len(schema.arguments)):
        info = schema.arguments[i]
        if info.kwarg_only:
            if info.name in kwargs:
                new_kwargs[info.name] = kwargs[info.name]
            else:
                new_kwargs[info.name] = info.default_value
        else:
            if i < len(args):
                new_args.append(args[i])
            else:
                new_args.append(info.default_value)
    return tuple(new_args), new_kwargs
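
# Example usage of fill_defaults (illustrative sketch; `my_schema` is a hypothetical
# schema of the form "foo(Tensor x, int n=1, *, bool flag=False) -> Tensor"):
#
#     fill_defaults(my_schema, (x,), {})
#     # -> ((x, 1), {"flag": False})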


def zip_schema(
    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
    """zips schema.arguments and (args, kwargs) together.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, kwargs must be keyword-only arguments and default values may be omitted.
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    for i in range(len(schema.arguments)):
        info = schema.arguments[i]
        if info.kwarg_only:
            if info.name in kwargs:
                yield info, kwargs[info.name]
            continue
        if i >= len(args):
            # args that are equal to their default values are not populated
            # if they are followed by args that are equal to their defaults.
            # Skip these.
            continue
        yield info, args[i]
    return
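
# Example usage of zip_schema (illustrative sketch; `my_schema` is a hypothetical
# schema of the form "foo(Tensor x, int n=1, *, bool flag=False) -> Tensor"):
#
#     list(zip_schema(my_schema, (x,), {"flag": True}))
#     # -> [(<Argument x>, x), (<Argument flag>, True)]
#     # the defaulted positional `n` is skipped because it was not passed.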


def hop_schema_from_fx_node(node):
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            assert node.op == "get_attr"
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bind the arguments to make sure the number of inputs is correct
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # We treat example_output as a single return value. This differentiates
    # (1) returning a single value vs. (2) returning a tuple with one element.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )


def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    if is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    # It's suspicious if the op is not mutable but returns nothing, so we
    # return False out of an abundance of caution
    if not schema.is_mutable:
        return False
    if len(schema.returns) > 0:
        return False
    # If the op returns nothing, then it has a trivial fake impl.
    return True
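
# Example (illustrative sketch): a custom op with a hypothetical schema like
#     "mylib::record_(Tensor(a!) x) -> ()"
# mutates its input and returns nothing, so its fake impl can trivially be a
# no-op; an op that returns Tensors cannot get a trivial fake impl this way.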


def requires_set_python_module() -> bool:
    """If an op was defined in C++ and extended from Python using the
    torch.library APIs, returns whether we require that there has been an
    m.set_python_module("mylib.ops") call from C++ that associates
    the C++ op with a python module.
    """
    return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)


def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
    overload_types = []
    args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
    for a in args_flattened:
        # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
        # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
        # where in one case we only include tensors with the python key, and in another
        # we include **all** tensors.
        if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
            torch._C.DispatchKey.Python
        ):
            overload_types.append(type(a))
    # TODO: check that I got these args correct (in C++, we pass in "0000"??)

    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)


def has_kwarg_only_args(schema: _C.FunctionSchema):
    return any(a.kwarg_only for a in schema.arguments)


def has_kwarg_only_tensors(schema: _C.FunctionSchema):
    for a in schema.arguments:
        if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
            continue
        if not a.kwarg_only:
            continue
        return True
    return False
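
# Example (illustrative sketch): for a hypothetical schema like
#     "foo(Tensor x, *, Tensor weight, bool flag=False) -> Tensor"
# has_kwarg_only_args returns True (weight, flag are kwarg-only) and
# has_kwarg_only_tensors returns True (weight); with only non-Tensor
# kwarg-only args, the latter would be False.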


def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
    """
    Given a schema, returns True if the schema has a Tensor arg.
    A Tensor arg is any arg with a type annotation that might involve Tensor.
    """
    return any(
        (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
        for a in schema.arguments
    )


def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
    """
    Given a schema, returns the index of the `device: torch.device` argument.
    If it does not exist, returns None.
    """
    for index, arg in enumerate(schema.arguments):
        if arg.type is _C.DeviceObjType.get() and arg.name == "device":
            return index
    return None
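
# Example (illustrative sketch): for a hypothetical schema like
#     "foo(Tensor x, Device device) -> Tensor"
# get_device_arg_index returns 1; if there is no `device: torch.device`
# argument, it returns None.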
319