# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from functools import lru_cache
from typing import Dict, List

import executorch.exir as exir
import torch

from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config


@lru_cache(maxsize=None)
def _get_bilinear_2d_graphs() -> Dict[torch.fx.GraphModule, bool]:
    """Capture and lower reference graphs for 2x bilinear upsampling.

    One edge-dialect graph is produced for every combination of
    ``align_corners`` (True/False), capture config (``_unlift`` False/True,
    both with ``enable_aot=True``) and ``skip_dim_order`` flag (True/False),
    using a single sample input of shape ``(1, 3, 4, 4)``.

    The result is memoized by ``lru_cache``, so the captures run only on the
    first call in a process and every later call returns the same dict
    object. The public accessors in this module therefore never hand that
    object out directly.

    Returns:
        A dict mapping each captured graph module to the ``align_corners``
        value it was traced with, in capture order.
    """

    class bilinear2d(torch.nn.Module):
        # NOTE(review): class name intentionally left unchanged; it may
        # surface in exported module metadata, so renaming is not purely
        # cosmetic.
        def __init__(self, align_corners):
            super().__init__()
            self.align_corners = align_corners

        def forward(self, x):
            return torch.nn.functional.interpolate(
                x,
                scale_factor=2,
                mode="bilinear",
                align_corners=self.align_corners,
                antialias=False,
            )

    sample_inputs = (torch.randn(1, 3, 4, 4),)
    capture_configs = [
        exir.CaptureConfig(enable_aot=True, _unlift=False),
        exir.CaptureConfig(enable_aot=True, _unlift=True),
    ]

    graphs: Dict[torch.fx.GraphModule, bool] = {}
    # itertools.product yields combinations in the same order as the
    # equivalent nested loops: align_corners outermost, skip_dim_order
    # innermost. This keeps the dict's insertion order stable.
    for align_corners, config, skip_dim_order_flag in itertools.product(
        [True, False], capture_configs, [True, False]
    ):
        # Build a fresh module instance for each capture.
        edge = exir.capture(
            bilinear2d(align_corners), sample_inputs, config
        ).to_edge(
            config=get_xnnpack_edge_compile_config(
                skip_dim_order=skip_dim_order_flag
            )
        )
        graphs[edge.exported_program.graph_module] = align_corners
    return graphs


def get_graphs() -> List[torch.fx.GraphModule]:
    """Return the cached bilinear 2D reference graphs as a new list."""
    return list(_get_bilinear_2d_graphs())


def get_graphs_dict() -> Dict[torch.fx.GraphModule, bool]:
    """Return a mapping from each reference graph to its ``align_corners`` value.

    A shallow copy of the memoized dict is returned, so a caller that mutates
    the result cannot corrupt what every other caller receives. The graph
    module objects themselves are still shared.
    """
    return dict(_get_bilinear_2d_graphs())