1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10
11from executorch.examples.models.llama.llama_transformer import KVCache
12
13from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
14    QuantizedCacheType,
15    QuantizedKVCache,
16)
17
18
class QuantizedKVCacheTest(unittest.TestCase):
    """Compare ``QuantizedKVCache`` against the float ``KVCache`` it wraps.

    A float ``KVCache`` and a ``QuantizedKVCache`` built from it with
    ``QuantizedKVCache.from_float`` receive identical updates. After each
    update, the caches returned by both ``update`` calls are compared, within
    tolerance, at every position written so far.
    """

    # The quantized cache is lossy, so the comparison uses loose tolerances
    # rather than exact equality.
    RTOL = 1e-02
    ATOL = 1e-02

    def _init_cache(self):
        """Create the float reference cache ``self.kv_cache`` from the test config."""
        self.kv_cache = KVCache(
            self.max_batch_size,
            self.max_seq_len,
            self.n_kv_heads,
            self.head_dim,
            self.transpose_kv_cache,
            self.enable_dynamic_shape,
            dtype=self.dtype,
        )

    def _init_kv(self):
        """Return random ``(k, v)`` tensors covering ``self.seq_len`` positions.

        Shapes are ``(batch, n_kv_heads, seq_len, head_dim)`` when
        ``self.transpose_kv_cache`` is set and ``(batch, seq_len, n_kv_heads,
        head_dim)`` otherwise. ``batch`` is ``self.max_batch_size``, so a batch
        mismatch with the cache cannot be hidden by broadcasting.
        """
        batch = self.max_batch_size
        if self.transpose_kv_cache:
            shape = (batch, self.n_kv_heads, self.seq_len, self.head_dim)
        else:
            shape = (batch, self.seq_len, self.n_kv_heads, self.head_dim)
        k = torch.rand(shape, dtype=self.dtype)
        v = torch.rand(shape, dtype=self.dtype)
        return k, v

    def setUp(self):
        """Seed the RNG and set the default cache configuration."""
        torch.manual_seed(42)
        self.max_batch_size = 1
        self.max_seq_len = 5
        self.n_kv_heads = 8
        self.head_dim = 17
        self.enable_dynamic_shape = False
        self.transpose_kv_cache = False
        self.dtype = torch.float32

    def _index(self, t, positions):
        """Select ``positions`` along the sequence dimension of cache tensor ``t``."""
        if self.transpose_kv_cache:
            return t[:, :, positions, :]
        return t[:, positions, :, :]

    def _assert_caches_close(self, expected, actual, positions):
        """Assert ``expected`` and ``actual`` agree at ``positions`` within tolerance."""
        torch.testing.assert_close(
            self._index(expected, positions),
            self._index(actual, positions),
            rtol=self.RTOL,
            atol=self.ATOL,
        )

    def _update_and_check(self, quantized_kv_cache, input_pos, k, v, pos_to_check):
        """Apply one update to both caches and compare them at ``pos_to_check``.

        Args:
            quantized_kv_cache: the cache under test, built from ``self.kv_cache``.
            input_pos: positions written by this update.
            k, v: values written at ``input_pos`` into both caches.
            pos_to_check: positions whose contents must match after the update.
        """
        updated_k_cache, updated_v_cache = self.kv_cache.update(input_pos, k, v)
        dequantized_k_cache, dequantized_v_cache = quantized_kv_cache.update(
            input_pos, k, v
        )
        self._assert_caches_close(updated_k_cache, dequantized_k_cache, pos_to_check)
        self._assert_caches_close(updated_v_cache, dequantized_v_cache, pos_to_check)

    def _test_simple_update_fetch(self, is_transposed=False, is_dynamic_shape=False):
        """Write positions 0-2, then position 3, comparing the caches after each.

        Args:
            is_transposed: value used as ``transpose_kv_cache`` for the cache
                and for the k/v layout.
            is_dynamic_shape: value passed as ``enable_dynamic_shape`` to
                ``KVCache``.
        """
        self.transpose_kv_cache = is_transposed
        self.enable_dynamic_shape = is_dynamic_shape

        # First update: three positions in a single call.
        input_pos = torch.tensor([0, 1, 2])
        self.seq_len = input_pos.size(0)
        self._init_cache()
        k, v = self._init_kv()
        quantized_kv_cache = QuantizedKVCache.from_float(
            self.kv_cache, QuantizedCacheType.AffineAsymmetric
        )
        self._update_and_check(quantized_kv_cache, input_pos, k, v, input_pos)

        # Second update: a single new position. All four written positions are
        # compared, so earlier entries must still match after this write.
        input_pos = torch.tensor([3])
        self.seq_len = input_pos.size(0)
        k, v = self._init_kv()
        self._update_and_check(
            quantized_kv_cache, input_pos, k, v, torch.tensor([0, 1, 2, 3])
        )

    def test_simple_update_fetch_not_transposed(self):
        """Non-transposed layout, static shape."""
        self._test_simple_update_fetch()

    def test_simple_update_fetch_not_transposed_dynamic_shape(self):
        """Non-transposed layout, dynamic shape."""
        self._test_simple_update_fetch(is_dynamic_shape=True)

    def test_simple_update_fetch_transposed(self):
        """Transposed layout, static shape."""
        self._test_simple_update_fetch(is_transposed=True)

    def test_simple_update_fetch_transposed_dynamic_shape(self):
        """Transposed layout, dynamic shape."""
        self._test_simple_update_fetch(is_transposed=True, is_dynamic_shape=True)