xref: /aosp_15_r20/external/XNNPACK/test/weights-cache.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
#include <xnnpack/cache.h>

#include <algorithm>  // For std::rotate.
#include <cstdint>    // For uintptr_t.
#include <cstring>    // For memcpy.
#include <string>     // For std::string, std::to_string.
#include <thread>     // For std::thread.
#include <vector>     // For std::vector.

#include <xnnpack.h>
#include <xnnpack/common.h>

#include <gtest/gtest.h>
19*4bdc9457SAndroid Build Coastguard Worker 
cache_end(const xnn_weights_cache * cache)20*4bdc9457SAndroid Build Coastguard Worker static void* cache_end(const xnn_weights_cache* cache) {
21*4bdc9457SAndroid Build Coastguard Worker   return reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(cache->cache.weights.start) + cache->cache.weights.size);
22*4bdc9457SAndroid Build Coastguard Worker }
23*4bdc9457SAndroid Build Coastguard Worker 
write_weights(xnn_weights_cache * cache,const std::string & str)24*4bdc9457SAndroid Build Coastguard Worker static void write_weights(xnn_weights_cache* cache, const std::string& str) {
25*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(nullptr, xnn_reserve_space_in_weights_cache(cache, str.length()));
26*4bdc9457SAndroid Build Coastguard Worker   std::memcpy(cache_end(cache), str.data(), str.length());
27*4bdc9457SAndroid Build Coastguard Worker };
28*4bdc9457SAndroid Build Coastguard Worker 
// A weights cache initialized with the default size can be released.
TEST(WEIGHTS_CACHE, init_and_release)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  // ASSERT (not EXPECT): if initialization failed, releasing would operate on
  // an uninitialized struct.
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
36*4bdc9457SAndroid Build Coastguard Worker 
// Initializing a weights cache with an explicit size yields at least that much
// weights capacity, and the cache can be released.
TEST(WEIGHTS_CACHE, init_with_size_and_release)
{
  constexpr size_t four_mb = 4 * 1024 * 1024;
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  // ASSERT (not EXPECT): releasing an uninitialized cache is not meaningful.
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache_with_size(&cache, four_mb));
  // Allocation can be rounded up to alignment, so check GE instead of EQ.
  // EXPECT (not ASSERT) so the cache is still released if this check fails.
  EXPECT_GE(cache.cache.weights.capacity, four_mb);
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
47*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, release_null) {
  // Passing a null cache to xnn_release_weights_cache must report success.
  const xnn_status status = xnn_release_weights_cache(nullptr);
  EXPECT_EQ(xnn_status_success, status);
}
52*4bdc9457SAndroid Build Coastguard Worker 
// Exercises xnn_get_or_insert_weights_cache:
// - the first insertion is a miss stored at offset 0;
// - writing identical bytes again is a hit that returns the existing offset
//   and does not grow the weights buffer;
// - different bytes are a miss appended at the next offset.
TEST(WEIGHTS_CACHE, get_or_insert)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  // ASSERT (not EXPECT): every step below dereferences the cache.
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  // First insertion: a miss.
  // ASSERT_NO_FATAL_FAILURE stops the test if the reservation inside the
  // helper failed, instead of calling get_or_insert without a reservation.
  ASSERT_NO_FATAL_FAILURE(write_weights(&cache, "1234"));
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, cache.cache.weights.start, 4));
  ASSERT_EQ(0, cache.cache.hits);
  ASSERT_EQ(1, cache.cache.misses);
  ASSERT_EQ(1, cache.cache.num_entries);
  ASSERT_EQ(4, cache.cache.weights.size);

  void* span2_weights = cache_end(&cache);
  // Simulate a cache hit: same bytes, so the existing offset 0 is returned.
  ASSERT_NO_FATAL_FAILURE(write_weights(&cache, "1234"));
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, span2_weights, 4));
  ASSERT_EQ(1, cache.cache.hits);
  ASSERT_EQ(1, cache.cache.misses);
  ASSERT_EQ(4, cache.cache.weights.size);

  void* span3_weights = cache_end(&cache);
  // Simulate a cache miss: new bytes are appended at offset 4.
  ASSERT_NO_FATAL_FAILURE(write_weights(&cache, "5678"));
  ASSERT_EQ(4, xnn_get_or_insert_weights_cache(&cache, span3_weights, 4));
  ASSERT_EQ(1, cache.cache.hits);
  ASSERT_EQ(2, cache.cache.misses);
  ASSERT_EQ(2, cache.cache.num_entries);
  ASSERT_EQ(8, cache.cache.weights.size);

  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
84*4bdc9457SAndroid Build Coastguard Worker 
// Inserting as many distinct entries as the cache initially has buckets must
// grow the bucket table. Afterwards every entry must still be found (all hits)
// at its original offset.
TEST(WEIGHTS_CACHE, grow) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_weights_cache cache;
  // ASSERT (not EXPECT): the loops below dereference the cache.
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  const size_t old_num_buckets = cache.cache.num_buckets;
  for (size_t i = 0, expected_offset = 0; i < old_num_buckets; i++) {
    // Add many entries to force cache to grow.
    const std::string s = std::to_string(i);
    ASSERT_NO_FATAL_FAILURE(write_weights(&cache, s));
    ASSERT_EQ(expected_offset, xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), s.length()));
    expected_offset += s.length();
  }

  ASSERT_EQ(0, cache.cache.hits);
  ASSERT_EQ(old_num_buckets, cache.cache.num_entries);
  // Check that cache has grown.
  ASSERT_LT(old_num_buckets, cache.cache.num_buckets);
  // Check that all the entries are still in cache, at the same offsets.
  for (size_t i = 0, expected_offset = 0; i < old_num_buckets; i++) {
    const std::string s = std::to_string(i);
    ASSERT_NO_FATAL_FAILURE(write_weights(&cache, s));
    ASSERT_EQ(expected_offset, xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), s.length()));
    expected_offset += s.length();
  }
  // And now all of the lookups should be cache hits.
  ASSERT_EQ(old_num_buckets, cache.cache.hits);

  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
114*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_MEMORY, allocate_and_release) {
  // A weights buffer of the default size can be allocated and then released.
  xnn_weights_buffer buffer;
  const xnn_status alloc_status = xnn_allocate_weights_memory(&buffer, XNN_DEFAULT_WEIGHTS_BUFFER_SIZE);
  ASSERT_EQ(xnn_status_success, alloc_status);
  const xnn_status release_status = xnn_release_weights_memory(&buffer);
  ASSERT_EQ(xnn_status_success, release_status);
}
120*4bdc9457SAndroid Build Coastguard Worker 
// Reserving space that already fits leaves the buffer untouched. Reserving
// space in a full buffer grows its capacity while preserving both the size and
// the bytes already written.
TEST(WEIGHTS_MEMORY, grow) {
  xnn_weights_buffer b;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&b, 8));
  // Allocations rounded to page size, so it might not be 8.
  const size_t old_capacity = b.capacity;

  const std::string junk = "1234";
  std::memcpy(b.start, junk.data(), junk.length());
  b.size += junk.length();
  ASSERT_EQ(4, b.size);
  const uintptr_t old_weights = reinterpret_cast<uintptr_t>(b.start);

  // This should be a no-op, since we have enough space: neither the start
  // address nor the capacity may change.
  ASSERT_EQ(xnn_status_success, xnn_reserve_weights_memory(&b, 4));
  ASSERT_EQ(old_weights, reinterpret_cast<uintptr_t>(b.start));
  ASSERT_EQ(old_capacity, b.capacity);

  // Simulate copying bytes until we are full.
  b.size = old_capacity;

  const size_t old_size = b.size;
  ASSERT_EQ(xnn_status_success, xnn_reserve_weights_memory(&b, 4));

  // After growing, the new capacity should be bigger than the old one.
  ASSERT_LT(old_capacity, b.capacity);
  // At least 4 bytes free.
  ASSERT_GE(b.capacity, b.size + 4);
  // But size stays the same.
  ASSERT_EQ(old_size, b.size);

  // Check that after growing, the contents remain.
  const std::string actual(static_cast<char*>(b.start), static_cast<char*>(b.start) + junk.length());
  ASSERT_EQ(junk, actual);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&b));
}
156*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, finalize_empty) {
  // Finalizing a buffer that holds no weights leaves it with zero size and
  // zero capacity.
  constexpr size_t kInitialCapacity = 1024 * 1024;  // 1MB.
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, kInitialCapacity));

  // Freshly allocated: nothing stored, exactly the requested capacity.
  ASSERT_EQ(0, buffer.size);
  ASSERT_EQ(kInitialCapacity, buffer.capacity);

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  ASSERT_EQ(0, buffer.size);
  ASSERT_EQ(0, buffer.capacity);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
171*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, finalize) {
  // Finalizing a mostly-empty 1MB buffer keeps the stored bytes and never
  // increases its capacity.
  constexpr size_t kInitialCapacity = 1024 * 1024;  // 1MB.
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, kInitialCapacity));
  // May differ from the request because allocations are rounded to page size.
  const size_t capacity_before_finalize = buffer.capacity;

  // Store 4 bytes.
  const std::string payload = "1234";
  std::memcpy(buffer.start, payload.data(), payload.length());
  buffer.size += payload.length();
  ASSERT_EQ(4, buffer.size);

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
#if XNN_PLATFORM_WEB
  // Web does not support partial unmapping, so the capacity is unchanged.
  ASSERT_EQ(capacity_before_finalize, buffer.capacity);
#else
  // The finalized capacity depends on the page size, so this only checks that
  // it did not grow (GE), not that it strictly shrank.
  ASSERT_GE(capacity_before_finalize, buffer.capacity);
#endif
  // The stored size survives finalization.
  ASSERT_EQ(4, buffer.size);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
195*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, finalize_twice) {
  // Finalizing an already-finalized buffer succeeds and leaves it unchanged.
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, XNN_DEFAULT_WEIGHTS_BUFFER_SIZE));

  const std::string payload = "1234";
  std::memcpy(buffer.start, payload.data(), payload.length());
  buffer.size += payload.length();

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  const size_t finalized_capacity = buffer.capacity;

  // The second call reports success, and neither capacity nor size moves.
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  ASSERT_EQ(finalized_capacity, buffer.capacity);
  ASSERT_EQ(4, buffer.size);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
214*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, finalize_capacity_smaller_than_page_aligned_size) {
  // Request a tiny capacity (8 bytes, below the page size on every platform),
  // store 4 bytes, and check that finalization succeeds and keeps the size.
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, 8));

  const std::string payload = "1234";
  std::memcpy(buffer.start, payload.data(), payload.length());
  buffer.size += payload.length();

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  ASSERT_EQ(4, buffer.size);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
227*4bdc9457SAndroid Build Coastguard Worker 
// Many threads concurrently write identical weights. Exactly one lookup should
// miss and create the single entry; every other thread should hit it, so the
// weights buffer holds one copy.
TEST(WEIGHTS_CACHE, write_many_cache_hits) {
#if XNN_PLATFORM_WEB && !defined(__EMSCRIPTEN_PTHREADS__)
  GTEST_SKIP();
#endif
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  // ASSERT (not EXPECT): the threads below dereference the cache.
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  const std::string weights = "0123456789abcdefghij";
  const size_t weights_size = weights.size();
  auto write = [&] {
    // If the reservation inside write_weights failed, skip the lookup. It
    // expects the state left by a successful reservation.
    ASSERT_NO_FATAL_FAILURE(write_weights(&cache, weights));
    xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), weights_size);
  };
  constexpr size_t num_threads = 20;
  std::vector<std::thread> threads;
  threads.reserve(num_threads);

  for (size_t i = 0; i < num_threads; i++) {
    threads.emplace_back(write);
  }
  for (std::thread& t : threads) {
    t.join();
  }

  ASSERT_EQ(num_threads - 1, cache.cache.hits);
  ASSERT_EQ(1, cache.cache.misses);
  ASSERT_EQ(1, cache.cache.num_entries);
  ASSERT_EQ(weights_size, cache.cache.weights.size);
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
257*4bdc9457SAndroid Build Coastguard Worker 
// Many threads concurrently write distinct weights (rotations of the same
// string by distinct amounts). Every lookup should miss, producing one entry
// per thread.
TEST(WEIGHTS_CACHE, write_many_cache_misses) {
#if XNN_PLATFORM_WEB && !defined(__EMSCRIPTEN_PTHREADS__)
  GTEST_SKIP();
#endif
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  // ASSERT (not EXPECT): the threads below dereference the cache.
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  const std::string weights = "0123456789abcdefghij";
  const size_t weights_size = weights.size();
  auto write = [&](size_t i) {
    std::string rotated_weights = weights;
    std::rotate(rotated_weights.begin(), rotated_weights.begin() + i,
                rotated_weights.end());
    // If the reservation inside write_weights failed, skip the lookup. It
    // expects the state left by a successful reservation.
    ASSERT_NO_FATAL_FAILURE(write_weights(&cache, rotated_weights));
    xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), weights_size);
  };
  constexpr size_t num_threads = 20;
  // Rotations by 0..num_threads-1 are all distinct only if num_threads does
  // not exceed the (all-distinct-character) string length.
  ASSERT_LE(num_threads, weights_size);
  std::vector<std::thread> threads;
  threads.reserve(num_threads);

  for (size_t i = 0; i < num_threads; i++) {
    threads.emplace_back(write, i);
  }
  for (std::thread& t : threads) {
    t.join();
  }

  ASSERT_EQ(0, cache.cache.hits);
  ASSERT_EQ(num_threads, cache.cache.misses);
  ASSERT_EQ(num_threads, cache.cache.num_entries);
  ASSERT_EQ(weights_size * num_threads, cache.cache.weights.size);
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
291*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, operations_on_finalized_cache_hard) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_weights_cache cache;
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  const auto hard = xnn_weights_cache_finalization_kind_hard;
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_cache(&cache, hard));
  // A second finalization must be rejected.
  ASSERT_NE(xnn_status_success, xnn_finalize_weights_cache(&cache, hard));
  // Once hard-finalized, not even a single byte can be reserved.
  ASSERT_EQ(nullptr, xnn_reserve_space_in_weights_cache(&cache, 1));

  // Insertion must fail as well. Because the reservation above failed, the
  // cache mutex was never locked, so this call must not hang by trying to
  // unlock it.
  void* weights_begin = cache.cache.weights.start;
  ASSERT_EQ(XNN_CACHE_NOT_FOUND, xnn_get_or_insert_weights_cache(&cache, weights_begin, 4));

  ASSERT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
309*4bdc9457SAndroid Build Coastguard Worker 
TEST(WEIGHTS_CACHE, operations_on_finalized_cache_soft) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_weights_cache cache;
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  const auto soft = xnn_weights_cache_finalization_kind_soft;
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_cache(&cache, soft));
  // A second finalization must be rejected.
  ASSERT_NE(xnn_status_success, xnn_finalize_weights_cache(&cache, soft));
  // Requesting more than the whole capacity cannot be satisfied.
  const auto too_much = cache.cache.weights.capacity + 1;
  ASSERT_EQ(nullptr, xnn_reserve_space_in_weights_cache(&cache, too_much));

  ASSERT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
323*4bdc9457SAndroid Build Coastguard Worker 
// After soft finalization:
// - a lookup that hits an existing entry, with room left in the buffer, still
//   succeeds and does not grow the buffer;
// - a lookup that misses is rejected;
// - a reservation larger than the remaining free space fails.
TEST(WEIGHTS_CACHE, insert_into_finalized_cache_soft) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  ASSERT_NO_FATAL_FAILURE(write_weights(&cache, "1234"));
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, cache.cache.weights.start, 4));
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_cache(&cache, xnn_weights_cache_finalization_kind_soft));

  // Inserting into a finalized cache is okay as long as cache memory has space and it is a cache hit.
  ASSERT_LT(cache.cache.weights.size + 4, cache.cache.weights.capacity);
  ASSERT_NO_FATAL_FAILURE(write_weights(&cache, "1234"));
  void* hit_weights = cache_end(&cache);
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, hit_weights, 4));
  // A hit does not grow the weights buffer.
  ASSERT_EQ(4, cache.cache.weights.size);

  // Sufficient space, but cache miss. Take the pointer after writing: that is
  // where write_weights placed the new bytes.
  ASSERT_NO_FATAL_FAILURE(write_weights(&cache, "4567"));
  void* miss_weights = cache_end(&cache);
  ASSERT_EQ(XNN_CACHE_NOT_FOUND, xnn_get_or_insert_weights_cache(&cache, miss_weights, 4));

  // Not enough space in the finalized weights cache: 4 bytes are already in
  // use, so a request for the full capacity cannot fit.
  // Don't use write_weights here as it asserts xnn_reserve_space_in_weights_cache does not return nullptr.
  ASSERT_EQ(nullptr, xnn_reserve_space_in_weights_cache(&cache, cache.cache.weights.capacity));

  ASSERT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
351