# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.examples.models.llama.llama_transformer import KVCache

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)


class QuantizedKVCacheTest(unittest.TestCase):
    """Compare ``QuantizedKVCache.update`` against the float ``KVCache.update``.

    Each scenario builds a float ``KVCache``, derives a ``QuantizedKVCache``
    from it with ``QuantizedCacheType.AffineAsymmetric``, feeds both the same
    random k/v tensors and asserts that the tensors returned by the two
    ``update`` calls agree within ``rtol=1e-02`` / ``atol=1e-02`` at the
    checked sequence positions.
    """

    def setUp(self):
        """Seed torch's RNG and install the default cache configuration."""
        torch.manual_seed(42)
        self.max_batch_size = 1
        self.max_seq_len = 5
        self.n_kv_heads = 8
        self.head_dim = 17
        self.enable_dynamic_shape = False
        self.transpose_kv_cache = False
        self.dtype = torch.float32

    def _init_cache(self):
        """Create ``self.kv_cache`` from the configuration stored on ``self``."""
        self.kv_cache = KVCache(
            self.max_batch_size,
            self.max_seq_len,
            self.n_kv_heads,
            self.head_dim,
            self.transpose_kv_cache,
            self.enable_dynamic_shape,
            dtype=self.dtype,
        )

    def _init_kv(self):
        """Return a ``(k, v)`` pair of random tensors for ``self.seq_len`` positions.

        The sequence axis is dim 2 when ``self.transpose_kv_cache`` is set and
        dim 1 otherwise; the leading (batch) dimension is always 1.
        ``self.seq_len`` must be assigned by the caller beforehand.
        """
        shape = (
            (1, self.n_kv_heads, self.seq_len, self.head_dim)
            if self.transpose_kv_cache
            else (1, self.seq_len, self.n_kv_heads, self.head_dim)
        )
        # k is drawn before v so the seeded random stream is consumed in a
        # fixed order.
        keys = torch.rand(shape, dtype=self.dtype)
        values = torch.rand(shape, dtype=self.dtype)
        return keys, values

    def _select_positions(self, cache, positions):
        """Index ``cache`` along its sequence axis with ``positions``."""
        if self.transpose_kv_cache:
            return cache[:, :, positions, :]
        return cache[:, positions, :, :]

    def _update_and_compare(self, quantized_kv_cache, write_pos, check_pos, k, v):
        """Write ``k``/``v`` at ``write_pos`` into both caches and compare them.

        The float cache is updated first, then the quantized one. The tensors
        returned by the two ``update`` calls are compared at ``check_pos``,
        with the float result passed as ``actual`` and the quantized cache's
        result as ``expected``.
        """
        float_k, float_v = self.kv_cache.update(write_pos, k, v)
        quant_k, quant_v = quantized_kv_cache.update(write_pos, k, v)

        for float_out, quant_out in ((float_k, quant_k), (float_v, quant_v)):
            torch.testing.assert_close(
                self._select_positions(float_out, check_pos),
                self._select_positions(quant_out, check_pos),
                rtol=1e-02,
                atol=1e-02,
            )

    def _test_simple_update_fetch(self, is_transposed=False, is_dynamic_shape=False):
        """Run a two-step update scenario for the requested cache layout.

        Step one writes positions 0-2 and checks those positions; step two
        writes position 3 and checks positions 0-3, so the entries written in
        step one are re-verified after the second update.
        """
        self.transpose_kv_cache = is_transposed
        self.enable_dynamic_shape = is_dynamic_shape

        # Step 1: fill the first three positions of freshly built caches.
        first_positions = torch.tensor([0, 1, 2])
        self.seq_len = first_positions.size(0)
        self._init_cache()
        k, v = self._init_kv()
        quantized_kv_cache = QuantizedKVCache.from_float(
            self.kv_cache, QuantizedCacheType.AffineAsymmetric
        )
        self._update_and_compare(
            quantized_kv_cache, first_positions, first_positions, k, v
        )

        # Step 2: append a single position, then check everything written so far.
        next_position = torch.tensor([3])
        self.seq_len = next_position.size(0)
        k, v = self._init_kv()
        self._update_and_compare(
            quantized_kv_cache, next_position, torch.tensor([0, 1, 2, 3]), k, v
        )

    def test_simple_update_fetch_not_transposed(self):
        self._test_simple_update_fetch()

    def test_simple_update_fetch_not_transposed_dynamic_shape(self):
        self._test_simple_update_fetch(is_dynamic_shape=True)

    def test_simple_update_fetch_transposed(self):
        self._test_simple_update_fetch(is_transposed=True)

    def test_simple_update_fetch_transposed_dynamic_shape(self):
        self._test_simple_update_fetch(is_transposed=True, is_dynamic_shape=True)