# xref: /aosp_15_r20/external/executorch/backends/xnnpack/partition/graphs/bilinear_2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7from functools import lru_cache
8from typing import Dict, List
9
10import executorch.exir as exir
11import torch
12
13from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
14
15
@lru_cache(maxsize=None)
def _get_bilinear_2d_graphs() -> Dict[torch.fx.GraphModule, bool]:
    """Trace the bilinear-upsample reference graphs once and memoize them.

    A small module calling ``torch.nn.functional.interpolate`` in
    ``"bilinear"`` mode (``scale_factor=2``, ``antialias=False``) is captured
    with ``exir.capture`` and lowered with ``to_edge`` for every combination
    of:

    * ``align_corners`` in ``(True, False)``;
    * an AOT ``exir.CaptureConfig`` with ``_unlift=False`` and one with
      ``_unlift=True``;
    * ``skip_dim_order`` in ``(True, False)`` passed to
      ``get_xnnpack_edge_compile_config``.

    Returns:
        A dict mapping each lowered edge ``GraphModule`` to the
        ``align_corners`` flag it was traced with, one entry per combination.
        Because of ``lru_cache`` the very same dict object is returned on
        every call.
    """

    class Bilinear2d(torch.nn.Module):
        """Upsample the input by 2x using bilinear interpolation."""

        def __init__(self, align_corners):
            super().__init__()
            self.align_corners = align_corners

        def forward(self, x):
            return torch.nn.functional.interpolate(
                x,
                scale_factor=2,
                mode="bilinear",
                align_corners=self.align_corners,
                antialias=False,
            )

    # A single example input is shared by every capture below.
    example_args = (torch.randn(1, 3, 4, 4),)
    # Both capture configs are built once and reused for each align_corners.
    unlift_variants = [
        exir.CaptureConfig(enable_aot=True, _unlift=unlift) for unlift in (False, True)
    ]

    def _trace(align_corners, capture_config, skip_dim_order):
        # Capture the module, then lower it to the edge dialect with the
        # XNNPACK edge compile config; only the resulting graph is kept.
        captured = exir.capture(
            Bilinear2d(align_corners), example_args, capture_config
        )
        edge_config = get_xnnpack_edge_compile_config(skip_dim_order=skip_dim_order)
        return captured.to_edge(config=edge_config).exported_program.graph_module

    graphs = {}
    for align_corners in (True, False):
        for capture_config in unlift_variants:
            for skip_dim_order in (True, False):
                graph_module = _trace(align_corners, capture_config, skip_dim_order)
                graphs[graph_module] = align_corners
    return graphs
50
51
def get_graphs() -> List[torch.fx.GraphModule]:
    """Return the memoized bilinear-2d edge graphs as a new list.

    The list contains the keys of the cached graph -> ``align_corners``
    mapping, in insertion order. A fresh list is built on each call.
    """
    # Unpacking a dict yields its keys in insertion order.
    return [*_get_bilinear_2d_graphs()]
54
55
def get_graphs_dict() -> Dict[torch.fx.GraphModule, bool]:
    """Return a mapping from each bilinear-2d edge graph to its ``align_corners``.

    The underlying dict is memoized with ``lru_cache`` and shared by every
    caller. A shallow copy is returned so that a caller mutating the result
    (adding, removing or relabelling entries) cannot corrupt the cached
    patterns seen by later calls to this function or to ``get_graphs``.

    Returns:
        A new dict whose keys are the cached ``GraphModule`` objects (not
        re-traced) and whose values are the ``align_corners`` flags.
    """
    # Shallow copy: only the container is new; the GraphModule keys are the
    # same cached objects, so identity-based lookups keep working.
    return dict(_get_bilinear_2d_graphs())
58