# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from transformers import PretrainedConfig, StaticCache


class ETStaticCache(StaticCache):
    """
    A customized static cache implementation that overrides a few methods to
    make it exportable to ExecuTorch. This can be removed once transformers
    supports static cache for Phi3 properly.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: int,
        device,
        dtype=torch.float32,
    ) -> None:
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Derive the occupied sequence length from the cache contents rather
        # than from mutable bookkeeping state, which keeps this method
        # export-friendly: a cache position counts as filled if any element of
        # its key vector (batch 0, head 0) is nonzero.
        # pyre-fixme[16]: `ETStaticCache` has no attribute `key_cache`.
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        return self.get_seq_length(layer_idx)
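

# A minimal usage sketch, not part of this module: it assumes a transformers
# version whose StaticCache constructor matches the signature used above, and
# the model id and `AutoConfig` import are demo-only choices (older
# transformers releases may additionally need `trust_remote_code=True` for
# Phi-3 configs).
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    cache = ETStaticCache(
        config=config,
        max_batch_size=1,
        max_cache_len=128,
        device="cpu",
        dtype=torch.float32,
    )
    # A freshly allocated static cache is all zeros, so the occupancy-based
    # sequence length starts at 0 and grows as decode steps write into it.
    print(cache.get_seq_length())  # expected: 0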