xref: /aosp_15_r20/external/executorch/exir/operator/manip.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9"""
10This module contains APIs to manipulate ops.
11"""
import functools
from dataclasses import dataclass
from typing import Callable, Dict

import torch
from executorch.exir.tensor import TensorSpec
17
18
@dataclass
class ScratchTensorMetadata:
    """
    Metadata describing a scratch tensor required by an out-variant op.

    A get_scratch_metas function returns a dictionary of these objects; the
    layout, device and sparsity fields have defaults, so only the dtype and
    shape must be supplied.
    """

    # Element dtype of the scratch tensor.
    dtype: torch.dtype
    # Shape of the scratch tensor.
    shape: torch.Size
    # Memory layout; strided unless overridden.
    layout: torch.layout = torch.strided
    # Device the tensor is described for; CPU unless overridden.
    # NOTE(review): attach_get_scratch_metas_fn's TensorSpec conversion does
    # not forward this field — confirm that is intended.
    device: torch.device = torch.device("cpu")
    # Whether the scratch tensor is sparse.
    is_sparse: bool = False
26
27
# Type of a user-supplied get_scratch_metas function: it takes the same
# arguments as the out-variant op (Tensor arguments replaced by TensorSpec)
# and returns a dict keyed by str (presumably the scratch tensor's name —
# confirm with callers) whose values are ScratchTensorMetadata.
ScratchCallableType = Callable[
    ...,
    Dict[str, ScratchTensorMetadata]
]
29
30
def attach_get_scratch_metas_fn(
    out_variant: torch._ops.OpOverload,
) -> Callable[[ScratchCallableType], ScratchCallableType]:
    """
    Build a decorator that attaches a get_scratch_metas function to `out_variant`.

    Apply the returned decorator to the get_scratch_metas function for the
    `out_variant` op. The decorator stores an adapted version of that function
    on the op as ``out_variant.get_scratch_metas``. It then returns the
    original function unchanged, so the function can still be called directly.

    The get_scratch_metas function has the same signature as the out variant
    op, with two differences:
    - all Tensor input arguments are replaced with TensorSpec;
    - it returns a dictionary of ScratchTensorMetadata.

    The attached ``out_variant.get_scratch_metas`` forwards its arguments to
    the user function. It then converts every ScratchTensorMetadata value of
    the returned dictionary into a TensorSpec, keeping the same keys.

    Args:
        out_variant: The out-variant OpOverload to attach the function to.

    Returns:
        A decorator that performs the registration and returns its argument.
    """

    def to_tensor_spec(meta: ScratchTensorMetadata) -> TensorSpec:
        # Scratch tensors are described as non-constant and not requiring grad.
        # NOTE(review): meta.device is not forwarded here — confirm whether
        # TensorSpec should receive it.
        return TensorSpec(
            const=False,
            requires_grad=False,
            # fields copied from ScratchTensorMetadata
            dtype=meta.dtype,
            shape=meta.shape,
            layout=meta.layout,
            is_sparse=meta.is_sparse,
        )

    def adapt_return_value(
        get_scratch_metas_fn: ScratchCallableType,
    ) -> Callable[..., Dict[str, TensorSpec]]:
        """
        Wrap `get_scratch_metas_fn` so that each ScratchTensorMetadata value in
        its returned dictionary is converted to a TensorSpec.
        """

        # Preserve the user function's __name__/__doc__ (and expose it via
        # __wrapped__) on the callable attached to the op.
        @functools.wraps(get_scratch_metas_fn)
        def adapted(*args: object, **kwargs: object) -> Dict[str, TensorSpec]:
            # Arguments mirror the out-variant op: only Tensor arguments are
            # TensorSpecs, any other op arguments are passed through as-is.
            meta_dict = get_scratch_metas_fn(*args, **kwargs)
            return {name: to_tensor_spec(meta) for name, meta in meta_dict.items()}

        return adapted

    def register(get_scratch_metas_fn: ScratchCallableType) -> ScratchCallableType:
        # pyre-fixme[16]: `OpOverload` has no attribute `get_scratch_metas`.
        out_variant.get_scratch_metas = adapt_return_value(get_scratch_metas_fn)
        # Return the undecorated function so direct callers still get metadata.
        return get_scratch_metas_fn

    return register
75
76
# pyre-ignore
def attach_calculate_upper_bound_shape_fn(func_op: torch._ops.OpOverload):
    """
    Build a decorator that registers an upper-bound shape calculator.

    `func_op` is the OpOverload for the functional op. The returned decorator
    stores the decorated function on it as
    ``func_op.calculate_upper_bound_shape`` and returns that function
    unchanged.
    """

    # pyre-ignore
    def register(calculate_upper_bound_shape_fn):
        # setattr sidesteps the static "no such attribute" complaint that a
        # direct assignment on OpOverload would trigger.
        setattr(func_op, "calculate_upper_bound_shape", calculate_upper_bound_shape_fn)
        return calculate_upper_bound_shape_fn

    return register
90