xref: /aosp_15_r20/external/XNNPACK/src/cache.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
#include "xnnpack/cache.h"

#include <assert.h> // For assert.
#include <stddef.h> // For size_t.
#include <stdint.h> // For uint32_t.
#include <string.h> // For memcmp, memcpy, memset.

#include "xnnpack.h"
#include "xnnpack/allocator.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/mutex.h"
17*4bdc9457SAndroid Build Coastguard Worker 
// Seed handed to murmur_hash3 for every key hashed by this file.
#define XNN_CACHE_HASH_SEED 7

// Bucket count used by the init functions that do not take an explicit count.
#define XNN_CACHE_INITIAL_BUCKETS 32

// The bucket count is multiplied by this factor each time the table grows.
#define XNN_CACHE_GROWTH_FACTOR 2

// Max load factor is 0.75 (3/4).
#define XNN_CACHE_MAX_LOAD 0.75

// Integer form of the load check: the table grows when
//   num_entries * ENTRIES_MULTIPLIER > num_buckets * BUCKETS_MULTIPLIER,
// i.e. when num_entries / num_buckets > 3 / 4.
#define XNN_CACHE_MAX_LOAD_ENTRIES_MULTIPLIER 4
#define XNN_CACHE_MAX_LOAD_BUCKETS_MULTIPLIER 3
25*4bdc9457SAndroid Build Coastguard Worker 
// MurmurHash3 implementation, copied from smhasher, with minor modifications in
// style and main loop.
28*4bdc9457SAndroid Build Coastguard Worker 
// Final avalanche step of MurmurHash3 (32-bit): alternating xor-shift and
// multiply rounds so that every input bit influences every output bit.
static inline uint32_t fmix32(uint32_t h)
{
  h = (h ^ (h >> 16)) * UINT32_C(0x85EBCA6B);
  h = (h ^ (h >> 13)) * UINT32_C(0xC2B2AE35);
  return h ^ (h >> 16);
}
39*4bdc9457SAndroid Build Coastguard Worker 
// Computes a 32-bit MurmurHash3-style hash of the `len` bytes at `key`,
// starting from `seed`.
//
// The input is consumed in 4-byte blocks in native byte order, followed by a
// 0-3 byte tail. Blocks are loaded with memcpy so that `key` may have any
// alignment and no object is accessed through an incompatible pointer type.
//
// NOTE: the main loop decrements `len`, so the value mixed in before the final
// avalanche is the tail length (0..3), not the original length as in the
// reference implementation. This is kept as-is so hash values do not change.
static uint32_t murmur_hash3(const void* key, size_t len, uint32_t seed)
{
  const uint8_t* data = (const uint8_t*) key;

  uint32_t h1 = seed;

  const uint32_t c1 = UINT32_C(0xCC9E2D51);
  const uint32_t c2 = UINT32_C(0x1B873593);

  // Body: one 32-bit block per iteration.
  for (; len >= sizeof(uint32_t); len -= sizeof(uint32_t)) {
    uint32_t k1;
    memcpy(&k1, data, sizeof(k1));
    data += sizeof(k1);

    k1 *= c1;
    k1 = math_rotl_u32(k1, 15);
    k1 *= c2;

    h1 ^= k1;
    h1 = math_rotl_u32(h1, 13);
    h1 = h1 * 5 + UINT32_C(0xE6546B64);
  }

  // Tail: the remaining 0-3 bytes, assembled least-significant byte first.
  const uint8_t* tail = data;

  uint32_t k1 = 0;

  switch (len & 3) {
    case 3:
      k1 ^= (uint32_t) tail[2] << 16;
      // fall through
    case 2:
      k1 ^= (uint32_t) tail[1] << 8;
      // fall through
    case 1:
      k1 ^= tail[0];
      k1 *= c1;
      k1 = math_rotl_u32(k1, 15);
      k1 *= c2;
      h1 ^= k1;
      break;
    default:
      // No tail bytes.
      break;
  }

  // `len` is the tail length here (see NOTE above).
  h1 ^= (uint32_t) len;

  return fmix32(h1);
}
83*4bdc9457SAndroid Build Coastguard Worker 
84*4bdc9457SAndroid Build Coastguard Worker #ifndef NDEBUG
85*4bdc9457SAndroid Build Coastguard Worker // This function is only used by an assert, so do not include it in non-debug
86*4bdc9457SAndroid Build Coastguard Worker // builds.
cache_size(struct xnn_cache * cache)87*4bdc9457SAndroid Build Coastguard Worker static inline size_t cache_size(struct xnn_cache* cache) {
88*4bdc9457SAndroid Build Coastguard Worker   switch (cache->type) {
89*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_type_code:
90*4bdc9457SAndroid Build Coastguard Worker       return cache->code.size;
91*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_type_weights:
92*4bdc9457SAndroid Build Coastguard Worker       return cache->weights.size;
93*4bdc9457SAndroid Build Coastguard Worker     default:
94*4bdc9457SAndroid Build Coastguard Worker       XNN_UNREACHABLE;
95*4bdc9457SAndroid Build Coastguard Worker   }
96*4bdc9457SAndroid Build Coastguard Worker   return SIZE_MAX;
97*4bdc9457SAndroid Build Coastguard Worker }
98*4bdc9457SAndroid Build Coastguard Worker #endif
99*4bdc9457SAndroid Build Coastguard Worker 
cache_start(struct xnn_cache * cache)100*4bdc9457SAndroid Build Coastguard Worker static inline void* cache_start(struct xnn_cache* cache) {
101*4bdc9457SAndroid Build Coastguard Worker   switch (cache->type) {
102*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_type_code:
103*4bdc9457SAndroid Build Coastguard Worker       return cache->code.start;
104*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_type_weights:
105*4bdc9457SAndroid Build Coastguard Worker       return cache->weights.start;
106*4bdc9457SAndroid Build Coastguard Worker     default:
107*4bdc9457SAndroid Build Coastguard Worker       XNN_UNREACHABLE;
108*4bdc9457SAndroid Build Coastguard Worker   }
109*4bdc9457SAndroid Build Coastguard Worker   return NULL;
110*4bdc9457SAndroid Build Coastguard Worker }
111*4bdc9457SAndroid Build Coastguard Worker 
xnn_init_cache_with_size(struct xnn_cache * cache,size_t num_buckets,enum xnn_cache_type cache_type)112*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_init_cache_with_size(struct xnn_cache* cache, size_t num_buckets, enum xnn_cache_type cache_type)
113*4bdc9457SAndroid Build Coastguard Worker {
114*4bdc9457SAndroid Build Coastguard Worker   memset(cache, 0, sizeof(struct xnn_cache));
115*4bdc9457SAndroid Build Coastguard Worker   cache->buckets = (struct xnn_cache_bucket*) xnn_allocate_zero_memory(num_buckets * sizeof(struct xnn_cache_bucket));
116*4bdc9457SAndroid Build Coastguard Worker   if (cache->buckets == NULL) {
117*4bdc9457SAndroid Build Coastguard Worker     xnn_log_error("fail to allocate memory for cache buckets");
118*4bdc9457SAndroid Build Coastguard Worker     return xnn_status_out_of_memory;
119*4bdc9457SAndroid Build Coastguard Worker   }
120*4bdc9457SAndroid Build Coastguard Worker 
121*4bdc9457SAndroid Build Coastguard Worker   cache->type = cache_type;
122*4bdc9457SAndroid Build Coastguard Worker   cache->num_buckets = num_buckets;
123*4bdc9457SAndroid Build Coastguard Worker   return xnn_status_success;
124*4bdc9457SAndroid Build Coastguard Worker }
125*4bdc9457SAndroid Build Coastguard Worker 
xnn_init_code_cache_with_size(struct xnn_code_cache * cache,size_t num_buckets)126*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_init_code_cache_with_size(struct xnn_code_cache* cache, size_t num_buckets)
127*4bdc9457SAndroid Build Coastguard Worker {
128*4bdc9457SAndroid Build Coastguard Worker   memset(cache, 0, sizeof(struct xnn_code_cache));
129*4bdc9457SAndroid Build Coastguard Worker   enum xnn_status status = xnn_status_success;
130*4bdc9457SAndroid Build Coastguard Worker   status = xnn_init_cache_with_size(&cache->cache, num_buckets, xnn_cache_type_code);
131*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
132*4bdc9457SAndroid Build Coastguard Worker     goto error;
133*4bdc9457SAndroid Build Coastguard Worker   }
134*4bdc9457SAndroid Build Coastguard Worker 
135*4bdc9457SAndroid Build Coastguard Worker   status = xnn_allocate_code_memory(&cache->cache.code, XNN_DEFAULT_CODE_BUFFER_SIZE);
136*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
137*4bdc9457SAndroid Build Coastguard Worker     goto error;
138*4bdc9457SAndroid Build Coastguard Worker   }
139*4bdc9457SAndroid Build Coastguard Worker 
140*4bdc9457SAndroid Build Coastguard Worker   return xnn_status_success;
141*4bdc9457SAndroid Build Coastguard Worker 
142*4bdc9457SAndroid Build Coastguard Worker error:
143*4bdc9457SAndroid Build Coastguard Worker   xnn_release_code_cache(cache);
144*4bdc9457SAndroid Build Coastguard Worker   return status;
145*4bdc9457SAndroid Build Coastguard Worker }
146*4bdc9457SAndroid Build Coastguard Worker 
xnn_init_code_cache(struct xnn_code_cache * cache)147*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_init_code_cache(struct xnn_code_cache* cache)
148*4bdc9457SAndroid Build Coastguard Worker {
149*4bdc9457SAndroid Build Coastguard Worker   return xnn_init_code_cache_with_size(cache, XNN_CACHE_INITIAL_BUCKETS);
150*4bdc9457SAndroid Build Coastguard Worker }
151*4bdc9457SAndroid Build Coastguard Worker 
cache_buckets_grow(struct xnn_cache * cache)152*4bdc9457SAndroid Build Coastguard Worker static bool cache_buckets_grow(struct xnn_cache* cache)
153*4bdc9457SAndroid Build Coastguard Worker {
154*4bdc9457SAndroid Build Coastguard Worker   const size_t new_num_buckets = cache->num_buckets * XNN_CACHE_GROWTH_FACTOR;
155*4bdc9457SAndroid Build Coastguard Worker   assert(is_po2(new_num_buckets));
156*4bdc9457SAndroid Build Coastguard Worker   struct xnn_cache tmp_cache;
157*4bdc9457SAndroid Build Coastguard Worker   xnn_init_cache_with_size(&tmp_cache, new_num_buckets, cache->type);
158*4bdc9457SAndroid Build Coastguard Worker 
159*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < cache->num_buckets; i++) {
160*4bdc9457SAndroid Build Coastguard Worker     struct xnn_cache_bucket b = cache->buckets[i];
161*4bdc9457SAndroid Build Coastguard Worker     if (b.size == 0) {
162*4bdc9457SAndroid Build Coastguard Worker       continue;
163*4bdc9457SAndroid Build Coastguard Worker     }
164*4bdc9457SAndroid Build Coastguard Worker 
165*4bdc9457SAndroid Build Coastguard Worker     // Find the first empty slot by linear probing to insert. No need to check
166*4bdc9457SAndroid Build Coastguard Worker     // hashes since we are not looking up anything, just moving things around
167*4bdc9457SAndroid Build Coastguard Worker     // into a bigger hash table.
168*4bdc9457SAndroid Build Coastguard Worker     const size_t mask = tmp_cache.num_buckets - 1;
169*4bdc9457SAndroid Build Coastguard Worker     size_t idx = b.hash & mask;
170*4bdc9457SAndroid Build Coastguard Worker     while (tmp_cache.buckets[idx].size != 0) {
171*4bdc9457SAndroid Build Coastguard Worker       idx = (idx + 1) & mask;
172*4bdc9457SAndroid Build Coastguard Worker     }
173*4bdc9457SAndroid Build Coastguard Worker     tmp_cache.buckets[idx].hash = b.hash;
174*4bdc9457SAndroid Build Coastguard Worker     tmp_cache.buckets[idx].size = b.size;
175*4bdc9457SAndroid Build Coastguard Worker     tmp_cache.buckets[idx].offset = b.offset;
176*4bdc9457SAndroid Build Coastguard Worker   }
177*4bdc9457SAndroid Build Coastguard Worker 
178*4bdc9457SAndroid Build Coastguard Worker   xnn_release_memory(cache->buckets);
179*4bdc9457SAndroid Build Coastguard Worker 
180*4bdc9457SAndroid Build Coastguard Worker   cache->buckets = tmp_cache.buckets;
181*4bdc9457SAndroid Build Coastguard Worker   cache->num_buckets = tmp_cache.num_buckets;
182*4bdc9457SAndroid Build Coastguard Worker   return true;
183*4bdc9457SAndroid Build Coastguard Worker }
184*4bdc9457SAndroid Build Coastguard Worker 
// Compares the `size` bytes at `ptr` with the `size` bytes stored at `offset`
// from the start of the cache's buffer. Returns true when they are identical.
static inline bool bytes_equal(struct xnn_cache* cache, void* ptr, size_t size, size_t offset)
{
  const uintptr_t cached_address = (uintptr_t) cache_start(cache) + offset;
  void* cached_bytes = (void*) cached_address;
  return memcmp(ptr, cached_bytes, size) == 0;
}
189*4bdc9457SAndroid Build Coastguard Worker 
lookup(struct xnn_cache * cache,void * ptr,size_t size,uint32_t hash,size_t * index)190*4bdc9457SAndroid Build Coastguard Worker static bool lookup(struct xnn_cache* cache, void* ptr, size_t size, uint32_t hash, size_t* index)
191*4bdc9457SAndroid Build Coastguard Worker {
192*4bdc9457SAndroid Build Coastguard Worker   assert(is_po2(cache->num_buckets));
193*4bdc9457SAndroid Build Coastguard Worker   const size_t mask = cache->num_buckets - 1;
194*4bdc9457SAndroid Build Coastguard Worker   size_t idx = hash & mask;
195*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_cache_bucket* buckets = cache->buckets;
196*4bdc9457SAndroid Build Coastguard Worker 
197*4bdc9457SAndroid Build Coastguard Worker   // Linear probing.
198*4bdc9457SAndroid Build Coastguard Worker   while (buckets[idx].size != 0 &&
199*4bdc9457SAndroid Build Coastguard Worker          !(buckets[idx].hash == hash &&
200*4bdc9457SAndroid Build Coastguard Worker            size == buckets[idx].size &&
201*4bdc9457SAndroid Build Coastguard Worker            bytes_equal(cache, ptr, buckets[idx].size, buckets[idx].offset))) {
202*4bdc9457SAndroid Build Coastguard Worker     idx = (idx + 1) & mask;
203*4bdc9457SAndroid Build Coastguard Worker   }
204*4bdc9457SAndroid Build Coastguard Worker   *index = idx;
205*4bdc9457SAndroid Build Coastguard Worker   if (buckets[idx].size == 0) {
206*4bdc9457SAndroid Build Coastguard Worker     return false;
207*4bdc9457SAndroid Build Coastguard Worker   } else {
208*4bdc9457SAndroid Build Coastguard Worker     return true;
209*4bdc9457SAndroid Build Coastguard Worker   }
210*4bdc9457SAndroid Build Coastguard Worker }
211*4bdc9457SAndroid Build Coastguard Worker 
insert(struct xnn_cache * cache,void * ptr,size_t size)212*4bdc9457SAndroid Build Coastguard Worker static bool insert(struct xnn_cache* cache, void* ptr, size_t size)
213*4bdc9457SAndroid Build Coastguard Worker {
214*4bdc9457SAndroid Build Coastguard Worker   const uint32_t hash = murmur_hash3(ptr, size, /*seed=*/XNN_CACHE_HASH_SEED);
215*4bdc9457SAndroid Build Coastguard Worker   size_t idx;
216*4bdc9457SAndroid Build Coastguard Worker   const bool found = lookup(cache, ptr, size, hash, &idx);
217*4bdc9457SAndroid Build Coastguard Worker   if (found) {
218*4bdc9457SAndroid Build Coastguard Worker     return false;
219*4bdc9457SAndroid Build Coastguard Worker   }
220*4bdc9457SAndroid Build Coastguard Worker 
221*4bdc9457SAndroid Build Coastguard Worker   // Ensure we have enough buckets to keep under our load limit.
222*4bdc9457SAndroid Build Coastguard Worker   if (cache->num_entries * XNN_CACHE_MAX_LOAD_ENTRIES_MULTIPLIER >
223*4bdc9457SAndroid Build Coastguard Worker       cache->num_buckets * XNN_CACHE_MAX_LOAD_BUCKETS_MULTIPLIER) {
224*4bdc9457SAndroid Build Coastguard Worker     if (!cache_buckets_grow(cache)) {
225*4bdc9457SAndroid Build Coastguard Worker       // Can't grow hash table anymore.
226*4bdc9457SAndroid Build Coastguard Worker       xnn_log_error("failed to grow cache buckets");
227*4bdc9457SAndroid Build Coastguard Worker       return false;
228*4bdc9457SAndroid Build Coastguard Worker     }
229*4bdc9457SAndroid Build Coastguard Worker     xnn_log_debug("successfully grew cache buckets");
230*4bdc9457SAndroid Build Coastguard Worker 
231*4bdc9457SAndroid Build Coastguard Worker     // If the cache grew, idx is stale, since that is based on the old cache's num_buckets.
232*4bdc9457SAndroid Build Coastguard Worker     const bool found_in_grown_cache = lookup(cache, ptr, size, hash, &idx);
233*4bdc9457SAndroid Build Coastguard Worker     assert(!found_in_grown_cache);
234*4bdc9457SAndroid Build Coastguard Worker     (void) found_in_grown_cache;  // Silence unused variable warnings.
235*4bdc9457SAndroid Build Coastguard Worker   }
236*4bdc9457SAndroid Build Coastguard Worker 
237*4bdc9457SAndroid Build Coastguard Worker   // Check that ptr points into cache's buffer.
238*4bdc9457SAndroid Build Coastguard Worker   assert((uintptr_t) ptr >= (uintptr_t) cache_start(cache));
239*4bdc9457SAndroid Build Coastguard Worker   if (cache->type == xnn_cache_type_code) {
240*4bdc9457SAndroid Build Coastguard Worker     assert((uintptr_t) ptr < (uintptr_t) cache_start(cache) + cache_size(cache));
241*4bdc9457SAndroid Build Coastguard Worker   }
242*4bdc9457SAndroid Build Coastguard Worker 
243*4bdc9457SAndroid Build Coastguard Worker   const size_t offset = (uintptr_t) ptr - (uintptr_t) cache_start(cache);
244*4bdc9457SAndroid Build Coastguard Worker 
245*4bdc9457SAndroid Build Coastguard Worker   // Insert the entry.
246*4bdc9457SAndroid Build Coastguard Worker   cache->buckets[idx].size = size;
247*4bdc9457SAndroid Build Coastguard Worker   cache->buckets[idx].hash = hash;
248*4bdc9457SAndroid Build Coastguard Worker   cache->buckets[idx].offset = offset;
249*4bdc9457SAndroid Build Coastguard Worker   cache->num_entries++;
250*4bdc9457SAndroid Build Coastguard Worker   return true;
251*4bdc9457SAndroid Build Coastguard Worker }
252*4bdc9457SAndroid Build Coastguard Worker 
253*4bdc9457SAndroid Build Coastguard Worker // Checks if a generated microkernel is already in the cache, returns the offset
254*4bdc9457SAndroid Build Coastguard Worker // if found, XNN_CACHE_NOT_FOUND otherwise.
lookup_cache(struct xnn_cache * cache,void * ptr,size_t size)255*4bdc9457SAndroid Build Coastguard Worker static size_t lookup_cache(struct xnn_cache* cache, void* ptr, size_t size)
256*4bdc9457SAndroid Build Coastguard Worker {
257*4bdc9457SAndroid Build Coastguard Worker   const uint32_t hash = murmur_hash3(ptr, size, /*seed=*/XNN_CACHE_HASH_SEED);
258*4bdc9457SAndroid Build Coastguard Worker   size_t bucket_idx;
259*4bdc9457SAndroid Build Coastguard Worker   if (lookup(cache, ptr, size, hash, &bucket_idx)) {
260*4bdc9457SAndroid Build Coastguard Worker     cache->hits++;
261*4bdc9457SAndroid Build Coastguard Worker     return cache->buckets[bucket_idx].offset;
262*4bdc9457SAndroid Build Coastguard Worker   } else {
263*4bdc9457SAndroid Build Coastguard Worker     cache->misses++;
264*4bdc9457SAndroid Build Coastguard Worker     return XNN_CACHE_NOT_FOUND;
265*4bdc9457SAndroid Build Coastguard Worker   }
266*4bdc9457SAndroid Build Coastguard Worker }
267*4bdc9457SAndroid Build Coastguard Worker 
xnn_get_or_insert_cache(struct xnn_cache * cache,void * ptr,size_t size)268*4bdc9457SAndroid Build Coastguard Worker size_t xnn_get_or_insert_cache(struct xnn_cache* cache, void* ptr, size_t size)
269*4bdc9457SAndroid Build Coastguard Worker {
270*4bdc9457SAndroid Build Coastguard Worker   const size_t found_offset = lookup_cache(cache, ptr, size);
271*4bdc9457SAndroid Build Coastguard Worker   if (found_offset != XNN_CACHE_NOT_FOUND) {
272*4bdc9457SAndroid Build Coastguard Worker     if (cache->type == xnn_cache_type_code) {
273*4bdc9457SAndroid Build Coastguard Worker       // Found in the cache, rewind the buffer because code generators update buffer size.
274*4bdc9457SAndroid Build Coastguard Worker       cache->code.size -= size;
275*4bdc9457SAndroid Build Coastguard Worker     }
276*4bdc9457SAndroid Build Coastguard Worker     return found_offset;
277*4bdc9457SAndroid Build Coastguard Worker   }
278*4bdc9457SAndroid Build Coastguard Worker 
279*4bdc9457SAndroid Build Coastguard Worker   if (cache->type == xnn_cache_type_weights) {
280*4bdc9457SAndroid Build Coastguard Worker     // Cache miss, weights packing functions don't update buffer size, update it here.
281*4bdc9457SAndroid Build Coastguard Worker     cache->weights.size += size;
282*4bdc9457SAndroid Build Coastguard Worker   }
283*4bdc9457SAndroid Build Coastguard Worker 
284*4bdc9457SAndroid Build Coastguard Worker   const size_t offset = (uintptr_t) ptr - (uintptr_t) cache_start(cache);
285*4bdc9457SAndroid Build Coastguard Worker   if (!insert(cache, ptr, size)) {
286*4bdc9457SAndroid Build Coastguard Worker     return XNN_CACHE_NOT_FOUND;
287*4bdc9457SAndroid Build Coastguard Worker   }
288*4bdc9457SAndroid Build Coastguard Worker   return offset;
289*4bdc9457SAndroid Build Coastguard Worker }
290*4bdc9457SAndroid Build Coastguard Worker 
xnn_get_or_insert_code_cache(struct xnn_code_cache * cache,void * ptr,size_t size)291*4bdc9457SAndroid Build Coastguard Worker size_t xnn_get_or_insert_code_cache(struct xnn_code_cache* cache, void* ptr, size_t size)
292*4bdc9457SAndroid Build Coastguard Worker {
293*4bdc9457SAndroid Build Coastguard Worker   return xnn_get_or_insert_cache(&cache->cache, ptr, size);
294*4bdc9457SAndroid Build Coastguard Worker }
295*4bdc9457SAndroid Build Coastguard Worker 
xnn_release_code_cache(struct xnn_code_cache * cache)296*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_release_code_cache(struct xnn_code_cache* cache)
297*4bdc9457SAndroid Build Coastguard Worker {
298*4bdc9457SAndroid Build Coastguard Worker   if XNN_LIKELY(cache != NULL) {
299*4bdc9457SAndroid Build Coastguard Worker     assert(cache->cache.type == xnn_cache_type_code);
300*4bdc9457SAndroid Build Coastguard Worker     xnn_release_code_memory(&cache->cache.code);
301*4bdc9457SAndroid Build Coastguard Worker     xnn_release_memory(cache->cache.buckets);
302*4bdc9457SAndroid Build Coastguard Worker   }
303*4bdc9457SAndroid Build Coastguard Worker   return xnn_status_success;
304*4bdc9457SAndroid Build Coastguard Worker }
305*4bdc9457SAndroid Build Coastguard Worker 
xnn_internal_init_weights_cache(struct xnn_weights_cache * cache,size_t num_buckets,size_t buffer_size)306*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_internal_init_weights_cache(
307*4bdc9457SAndroid Build Coastguard Worker   struct xnn_weights_cache* cache,
308*4bdc9457SAndroid Build Coastguard Worker   size_t num_buckets,
309*4bdc9457SAndroid Build Coastguard Worker   size_t buffer_size)
310*4bdc9457SAndroid Build Coastguard Worker {
311*4bdc9457SAndroid Build Coastguard Worker   memset(cache, 0, sizeof(struct xnn_weights_cache));
312*4bdc9457SAndroid Build Coastguard Worker 
313*4bdc9457SAndroid Build Coastguard Worker   enum xnn_status status = xnn_status_success;
314*4bdc9457SAndroid Build Coastguard Worker   status = xnn_init_cache_with_size(&cache->cache, num_buckets, xnn_cache_type_weights);
315*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
316*4bdc9457SAndroid Build Coastguard Worker     goto error;
317*4bdc9457SAndroid Build Coastguard Worker   }
318*4bdc9457SAndroid Build Coastguard Worker 
319*4bdc9457SAndroid Build Coastguard Worker   status = xnn_allocate_weights_memory(&cache->cache.weights, buffer_size);
320*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
321*4bdc9457SAndroid Build Coastguard Worker     goto error;
322*4bdc9457SAndroid Build Coastguard Worker   }
323*4bdc9457SAndroid Build Coastguard Worker 
324*4bdc9457SAndroid Build Coastguard Worker   status = xnn_mutex_init(&cache->mutex);
325*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
326*4bdc9457SAndroid Build Coastguard Worker     goto error;
327*4bdc9457SAndroid Build Coastguard Worker   }
328*4bdc9457SAndroid Build Coastguard Worker 
329*4bdc9457SAndroid Build Coastguard Worker   return xnn_status_success;
330*4bdc9457SAndroid Build Coastguard Worker 
331*4bdc9457SAndroid Build Coastguard Worker error:
332*4bdc9457SAndroid Build Coastguard Worker   xnn_release_weights_cache(cache);
333*4bdc9457SAndroid Build Coastguard Worker   return status;
334*4bdc9457SAndroid Build Coastguard Worker }
335*4bdc9457SAndroid Build Coastguard Worker 
xnn_init_weights_cache_with_size(struct xnn_weights_cache * cache,size_t size)336*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_init_weights_cache_with_size(struct xnn_weights_cache* cache, size_t size)
337*4bdc9457SAndroid Build Coastguard Worker {
338*4bdc9457SAndroid Build Coastguard Worker   return xnn_internal_init_weights_cache(cache, XNN_CACHE_INITIAL_BUCKETS, size);
339*4bdc9457SAndroid Build Coastguard Worker }
340*4bdc9457SAndroid Build Coastguard Worker 
xnn_init_weights_cache(struct xnn_weights_cache * cache)341*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_init_weights_cache(struct xnn_weights_cache* cache)
342*4bdc9457SAndroid Build Coastguard Worker {
343*4bdc9457SAndroid Build Coastguard Worker   return xnn_init_weights_cache_with_size(cache, XNN_DEFAULT_WEIGHTS_BUFFER_SIZE);
344*4bdc9457SAndroid Build Coastguard Worker }
345*4bdc9457SAndroid Build Coastguard Worker 
xnn_finalize_weights_cache(struct xnn_weights_cache * cache,enum xnn_weights_cache_finalization_kind finalization_kind)346*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_finalize_weights_cache(
347*4bdc9457SAndroid Build Coastguard Worker   struct xnn_weights_cache* cache,
348*4bdc9457SAndroid Build Coastguard Worker   enum xnn_weights_cache_finalization_kind finalization_kind)
349*4bdc9457SAndroid Build Coastguard Worker {
350*4bdc9457SAndroid Build Coastguard Worker   switch (cache->finalization_state) {
351*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_hard_finalized:
352*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_soft_finalized:
353*4bdc9457SAndroid Build Coastguard Worker       xnn_log_error("failed to finalize an already final weights cache");
354*4bdc9457SAndroid Build Coastguard Worker       return xnn_status_invalid_state;
355*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_not_finalized: {
356*4bdc9457SAndroid Build Coastguard Worker       enum xnn_status status;
357*4bdc9457SAndroid Build Coastguard Worker       enum xnn_cache_state finalized_state;
358*4bdc9457SAndroid Build Coastguard Worker 
359*4bdc9457SAndroid Build Coastguard Worker       if (finalization_kind == xnn_weights_cache_finalization_kind_hard) {
360*4bdc9457SAndroid Build Coastguard Worker         xnn_log_debug("hard finalizing weights cache");
361*4bdc9457SAndroid Build Coastguard Worker         status = xnn_finalize_weights_memory(&cache->cache.weights);
362*4bdc9457SAndroid Build Coastguard Worker         // Also release the memory used by hash table (but not the weights memory).
363*4bdc9457SAndroid Build Coastguard Worker         xnn_release_memory(cache->cache.buckets);
364*4bdc9457SAndroid Build Coastguard Worker         cache->cache.buckets = NULL;
365*4bdc9457SAndroid Build Coastguard Worker         finalized_state = xnn_cache_state_hard_finalized;
366*4bdc9457SAndroid Build Coastguard Worker       } else {
367*4bdc9457SAndroid Build Coastguard Worker         xnn_log_debug("soft finalizing weights cache");
368*4bdc9457SAndroid Build Coastguard Worker         assert(finalization_kind == xnn_weights_cache_finalization_kind_soft);
369*4bdc9457SAndroid Build Coastguard Worker         // Finalize weights cache by reserving sufficient space for the insertion of the largest cached weights. This
370*4bdc9457SAndroid Build Coastguard Worker         // ensures that we have space to write packed weights to check for cache hits without growing and moving the
371*4bdc9457SAndroid Build Coastguard Worker         // memory. This has some memory overhead, which can be as large as the size of the largest cached weights,
372*4bdc9457SAndroid Build Coastguard Worker         // rounded up to page size.
373*4bdc9457SAndroid Build Coastguard Worker         status = xnn_reserve_weights_memory(&cache->cache.weights, cache->max_weights_size);
374*4bdc9457SAndroid Build Coastguard Worker         finalized_state = xnn_cache_state_soft_finalized;
375*4bdc9457SAndroid Build Coastguard Worker       }
376*4bdc9457SAndroid Build Coastguard Worker       if (status != xnn_status_success) {
377*4bdc9457SAndroid Build Coastguard Worker         xnn_log_error("failed to finalize weights cache memory");
378*4bdc9457SAndroid Build Coastguard Worker         return xnn_status_invalid_state;
379*4bdc9457SAndroid Build Coastguard Worker       }
380*4bdc9457SAndroid Build Coastguard Worker 
381*4bdc9457SAndroid Build Coastguard Worker       cache->finalization_state = finalized_state;
382*4bdc9457SAndroid Build Coastguard Worker       return xnn_status_success;
383*4bdc9457SAndroid Build Coastguard Worker     }
384*4bdc9457SAndroid Build Coastguard Worker   }
385*4bdc9457SAndroid Build Coastguard Worker }
386*4bdc9457SAndroid Build Coastguard Worker 
xnn_release_weights_cache(struct xnn_weights_cache * cache)387*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_release_weights_cache(struct xnn_weights_cache* cache)
388*4bdc9457SAndroid Build Coastguard Worker {
389*4bdc9457SAndroid Build Coastguard Worker   if XNN_LIKELY(cache != NULL) {
390*4bdc9457SAndroid Build Coastguard Worker     assert(cache->cache.type == xnn_cache_type_weights);
391*4bdc9457SAndroid Build Coastguard Worker     xnn_release_weights_memory(&cache->cache.weights);
392*4bdc9457SAndroid Build Coastguard Worker     if (cache->cache.buckets != NULL) {
393*4bdc9457SAndroid Build Coastguard Worker       xnn_release_memory(cache->cache.buckets);
394*4bdc9457SAndroid Build Coastguard Worker     }
395*4bdc9457SAndroid Build Coastguard Worker     const enum xnn_status status = xnn_mutex_destroy(&cache->mutex);
396*4bdc9457SAndroid Build Coastguard Worker     if (status != xnn_status_success) {
397*4bdc9457SAndroid Build Coastguard Worker       return status;
398*4bdc9457SAndroid Build Coastguard Worker     }
399*4bdc9457SAndroid Build Coastguard Worker   }
400*4bdc9457SAndroid Build Coastguard Worker   return xnn_status_success;
401*4bdc9457SAndroid Build Coastguard Worker }
402*4bdc9457SAndroid Build Coastguard Worker 
cache_has_space(struct xnn_weights_cache * cache,size_t n)403*4bdc9457SAndroid Build Coastguard Worker static inline bool cache_has_space(struct xnn_weights_cache* cache, size_t n)
404*4bdc9457SAndroid Build Coastguard Worker {
405*4bdc9457SAndroid Build Coastguard Worker   const struct xnn_weights_buffer buf = cache->cache.weights;
406*4bdc9457SAndroid Build Coastguard Worker   return buf.size + n <= buf.capacity;
407*4bdc9457SAndroid Build Coastguard Worker }
408*4bdc9457SAndroid Build Coastguard Worker 
xnn_reserve_space_in_weights_cache(struct xnn_weights_cache * cache,size_t n)409*4bdc9457SAndroid Build Coastguard Worker void* xnn_reserve_space_in_weights_cache(struct xnn_weights_cache* cache, size_t n) {
410*4bdc9457SAndroid Build Coastguard Worker   switch (cache->finalization_state) {
411*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_hard_finalized:
412*4bdc9457SAndroid Build Coastguard Worker       xnn_log_error("cannot reserve additional space in a finalized compact weights cache");
413*4bdc9457SAndroid Build Coastguard Worker       return NULL;
414*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_soft_finalized:
415*4bdc9457SAndroid Build Coastguard Worker       if (!cache_has_space(cache, n)) {
416*4bdc9457SAndroid Build Coastguard Worker         xnn_log_error("cannot reserve additional space in a finalized weights cache");
417*4bdc9457SAndroid Build Coastguard Worker         return NULL;
418*4bdc9457SAndroid Build Coastguard Worker       }
419*4bdc9457SAndroid Build Coastguard Worker       // If the cache is finalized, and has space for `n` bytes, we still want to lock the mutex, because we can have
420*4bdc9457SAndroid Build Coastguard Worker       // multiple writers attempting to write to this space.
421*4bdc9457SAndroid Build Coastguard Worker       break;
422*4bdc9457SAndroid Build Coastguard Worker     default:
423*4bdc9457SAndroid Build Coastguard Worker       break;
424*4bdc9457SAndroid Build Coastguard Worker   }
425*4bdc9457SAndroid Build Coastguard Worker 
426*4bdc9457SAndroid Build Coastguard Worker   enum xnn_status status = xnn_mutex_lock(&cache->mutex);
427*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
428*4bdc9457SAndroid Build Coastguard Worker     return NULL;
429*4bdc9457SAndroid Build Coastguard Worker   }
430*4bdc9457SAndroid Build Coastguard Worker 
431*4bdc9457SAndroid Build Coastguard Worker   struct xnn_weights_buffer* buffer = &cache->cache.weights;
432*4bdc9457SAndroid Build Coastguard Worker   status = xnn_reserve_weights_memory(buffer, n);
433*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
434*4bdc9457SAndroid Build Coastguard Worker     xnn_mutex_unlock(&cache->mutex);
435*4bdc9457SAndroid Build Coastguard Worker     return NULL;
436*4bdc9457SAndroid Build Coastguard Worker   }
437*4bdc9457SAndroid Build Coastguard Worker 
438*4bdc9457SAndroid Build Coastguard Worker   return (void*) ((uintptr_t) buffer->start + buffer->size);
439*4bdc9457SAndroid Build Coastguard Worker }
440*4bdc9457SAndroid Build Coastguard Worker 
xnn_get_or_insert_weights_cache(struct xnn_weights_cache * cache,void * ptr,size_t size)441*4bdc9457SAndroid Build Coastguard Worker size_t xnn_get_or_insert_weights_cache(struct xnn_weights_cache* cache, void* ptr, size_t size)
442*4bdc9457SAndroid Build Coastguard Worker {
443*4bdc9457SAndroid Build Coastguard Worker   size_t offset = XNN_CACHE_NOT_FOUND;
444*4bdc9457SAndroid Build Coastguard Worker 
445*4bdc9457SAndroid Build Coastguard Worker   switch (cache->finalization_state) {
446*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_hard_finalized: {
447*4bdc9457SAndroid Build Coastguard Worker       xnn_log_error("cannot insert into a finalized compact weights cache");
448*4bdc9457SAndroid Build Coastguard Worker       return XNN_CACHE_NOT_FOUND;
449*4bdc9457SAndroid Build Coastguard Worker     }
450*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_soft_finalized: {
451*4bdc9457SAndroid Build Coastguard Worker       // Inserting into a finalized weights cache is okay as long as:
452*4bdc9457SAndroid Build Coastguard Worker       // 1. there is sufficient space in the memory (to write the incoming packed weights), or
453*4bdc9457SAndroid Build Coastguard Worker       // 2. incoming packed weights is already in cache
454*4bdc9457SAndroid Build Coastguard Worker       if (!cache_has_space(cache, size)) {
455*4bdc9457SAndroid Build Coastguard Worker         xnn_log_error("insufficient extra space in finalized weights cache buffer");
456*4bdc9457SAndroid Build Coastguard Worker         return XNN_CACHE_NOT_FOUND;
457*4bdc9457SAndroid Build Coastguard Worker       }
458*4bdc9457SAndroid Build Coastguard Worker 
459*4bdc9457SAndroid Build Coastguard Worker       // We need to release the mutex from this point onwards, because xnn_reserve_space_in_weights would have returned
460*4bdc9457SAndroid Build Coastguard Worker       // non-NULL (which means that it locked the mutex).
461*4bdc9457SAndroid Build Coastguard Worker       const size_t found_offset = lookup_cache(&cache->cache, ptr, size);
462*4bdc9457SAndroid Build Coastguard Worker       if (found_offset == XNN_CACHE_NOT_FOUND) {
463*4bdc9457SAndroid Build Coastguard Worker         xnn_log_error("packed weights not found in finalized weights cache");
464*4bdc9457SAndroid Build Coastguard Worker       }
465*4bdc9457SAndroid Build Coastguard Worker 
466*4bdc9457SAndroid Build Coastguard Worker       offset = found_offset;
467*4bdc9457SAndroid Build Coastguard Worker       break;
468*4bdc9457SAndroid Build Coastguard Worker     }
469*4bdc9457SAndroid Build Coastguard Worker     case xnn_cache_state_not_finalized: {
470*4bdc9457SAndroid Build Coastguard Worker       offset = xnn_get_or_insert_cache(&cache->cache, ptr, size);
471*4bdc9457SAndroid Build Coastguard Worker       if (offset != XNN_CACHE_NOT_FOUND) {
472*4bdc9457SAndroid Build Coastguard Worker         // Found or inserted packed weights, update the largest size seen so far, this will be used when finalizing the
473*4bdc9457SAndroid Build Coastguard Worker         // weights cache, to ensure there is an extra space at the end for future cache checks.
474*4bdc9457SAndroid Build Coastguard Worker         cache->max_weights_size = max(size, cache->max_weights_size);
475*4bdc9457SAndroid Build Coastguard Worker       }
476*4bdc9457SAndroid Build Coastguard Worker       break;
477*4bdc9457SAndroid Build Coastguard Worker     }
478*4bdc9457SAndroid Build Coastguard Worker   }
479*4bdc9457SAndroid Build Coastguard Worker 
480*4bdc9457SAndroid Build Coastguard Worker   // Mutex is locked in xnn_reserve_space_in_weights_cache when it returns non-NULL, i.e. when cache is not finalized,
481*4bdc9457SAndroid Build Coastguard Worker   // or if it is xnn_cache_state_soft_finalized and has sufficient space.
482*4bdc9457SAndroid Build Coastguard Worker   const enum xnn_status status = xnn_mutex_unlock(&cache->mutex);
483*4bdc9457SAndroid Build Coastguard Worker   (void) status;
484*4bdc9457SAndroid Build Coastguard Worker   assert(status == xnn_status_success);
485*4bdc9457SAndroid Build Coastguard Worker   return offset;
486*4bdc9457SAndroid Build Coastguard Worker }
487*4bdc9457SAndroid Build Coastguard Worker 
xnn_weights_cache_is_finalized(struct xnn_weights_cache * cache)488*4bdc9457SAndroid Build Coastguard Worker bool xnn_weights_cache_is_finalized(struct xnn_weights_cache* cache) {
489*4bdc9457SAndroid Build Coastguard Worker   return cache->finalization_state != xnn_cache_state_not_finalized;
490*4bdc9457SAndroid Build Coastguard Worker }
491