# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache
from typing import List, Optional

import torch
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge
from torch import Tensor
from torch.export import export


@lru_cache(maxsize=None)
def get_graphs() -> List[torch.fx.GraphModule]:
    """
    Returns a list of SDPA graph modules, one per mask/dtype combination,
    exported and lowered to edge with the XNNPACK compile config.
    Memoized via lru_cache so the export only runs once per process.
    """

    # Minimal module wrapping F.scaled_dot_product_attention so that its
    # decomposed graph can be captured by torch.export.
    class SDPA(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout_p: float = 0.0
            self.is_causal: bool = False
            self.scale: Optional[float] = None

        def forward(
            self,
            q: Tensor,
            k: Tensor,
            v: Tensor,
            attn_mask: Optional[Tensor] = None,
        ):
            return torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.dropout_p,
                is_causal=self.is_causal,
                scale=self.scale,
            )

    batch_size = 8
    heads = 16
    seq_len = 32
    dim = 64

    q = torch.randn(batch_size, heads, seq_len, dim)
    k = torch.randn(batch_size, heads, seq_len, dim)
    v = torch.randn(batch_size, heads, seq_len, dim)

    # TODO: add support for
    # 1. None - mask should be inserted later on
    # 2. >2d tensor - requires general unsqueeze from newer xnnpack
    masks = [torch.full((seq_len, seq_len), 0, dtype=torch.float)]

    graphs = []
    for mask in masks:
        # These two dtypes seem to generate different graphs - P1136301928
        for dtype in [torch.float, torch.float16]:
            q = q.to(dtype)
            k = k.to(dtype)
            v = v.to(dtype)
            mask = mask.to(dtype)

            edge = to_edge(
                export(
                    SDPA(),  # pyre-ignore[16]
                    (
                        q,
                        k,
                        v,
                        mask,
                    ),
                ),
                compile_config=get_xnnpack_edge_compile_config(),
            )
            gm = edge.exported_program().graph_module
            graphs.append(gm)

    return graphs
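

# A minimal usage sketch: enumerate the cached pattern graphs and print each
# underlying FX graph. Assumes torch and the executorch XNNPACK utilities
# imported above are available in the environment.
if __name__ == "__main__":
    for i, gm in enumerate(get_graphs()):
        print(f"--- SDPA pattern graph {i} ---")
        print(gm.graph)  # the decomposed SDPA ops as a torch.fx.Graph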