1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h>
7*4bdc9457SAndroid Build Coastguard Worker
8*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> // For std::rotate.
9*4bdc9457SAndroid Build Coastguard Worker #include <cstdint> // For uintptr_t.
10*4bdc9457SAndroid Build Coastguard Worker #include <cstdint> // For uintptr_t.
11*4bdc9457SAndroid Build Coastguard Worker #include <cstring> // For memcpy.
12*4bdc9457SAndroid Build Coastguard Worker #include <cstring> // For memcpy.
13*4bdc9457SAndroid Build Coastguard Worker #include <thread> // For memcpy.
14*4bdc9457SAndroid Build Coastguard Worker
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/common.h>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
19*4bdc9457SAndroid Build Coastguard Worker
cache_end(const xnn_weights_cache * cache)20*4bdc9457SAndroid Build Coastguard Worker static void* cache_end(const xnn_weights_cache* cache) {
21*4bdc9457SAndroid Build Coastguard Worker return reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(cache->cache.weights.start) + cache->cache.weights.size);
22*4bdc9457SAndroid Build Coastguard Worker }
23*4bdc9457SAndroid Build Coastguard Worker
write_weights(xnn_weights_cache * cache,const std::string & str)24*4bdc9457SAndroid Build Coastguard Worker static void write_weights(xnn_weights_cache* cache, const std::string& str) {
25*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, xnn_reserve_space_in_weights_cache(cache, str.length()));
26*4bdc9457SAndroid Build Coastguard Worker std::memcpy(cache_end(cache), str.data(), str.length());
27*4bdc9457SAndroid Build Coastguard Worker };
28*4bdc9457SAndroid Build Coastguard Worker
// A weights cache created with the default size can be initialized and then
// released, both reporting success.
TEST(WEIGHTS_CACHE, init_and_release) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_weights_cache weights_cache;
  EXPECT_EQ(xnn_status_success, xnn_init_weights_cache(&weights_cache));
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&weights_cache));
}
36*4bdc9457SAndroid Build Coastguard Worker
// Initializing a weights cache with an explicit size yields a buffer of at
// least that capacity.
TEST(WEIGHTS_CACHE, init_with_size_and_release) {
  constexpr size_t kRequestedBytes = 4 * 1024 * 1024;  // 4 MiB.
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_weights_cache weights_cache;
  EXPECT_EQ(xnn_status_success, xnn_init_weights_cache_with_size(&weights_cache, kRequestedBytes));
  // The allocation may be rounded up (e.g. for alignment), so this is a lower
  // bound rather than an exact match.
  ASSERT_GE(weights_cache.cache.weights.capacity, kRequestedBytes);
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&weights_cache));
}
47*4bdc9457SAndroid Build Coastguard Worker
// Releasing a null cache pointer is accepted and reports success.
TEST(WEIGHTS_CACHE, release_null) {
  const xnn_status status = xnn_release_weights_cache(nullptr);
  EXPECT_EQ(xnn_status_success, status);
}
52*4bdc9457SAndroid Build Coastguard Worker
// Exercises the three outcomes of xnn_get_or_insert_weights_cache:
//   1. a first insertion (miss, stored at offset 0),
//   2. re-inserting identical bytes (hit, committed size unchanged),
//   3. inserting different bytes (miss, appended as a new entry).
//
// The pointer handed to xnn_get_or_insert_weights_cache is always taken *after*
// write_weights has reserved space. A reservation may grow (and therefore move)
// the weights buffer, so a pointer computed before it could be stale.
TEST(WEIGHTS_CACHE, get_or_insert)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  EXPECT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  // First insertion: nothing is cached yet, so this is a miss at offset 0.
  write_weights(&cache, "1234");
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, cache.cache.weights.start, 4));
  ASSERT_EQ(0, cache.cache.hits);
  ASSERT_EQ(1, cache.cache.misses);
  ASSERT_EQ(4, cache.cache.weights.size);

  // Simulate a cache hit: identical bytes resolve to the existing entry at
  // offset 0 and the committed size does not change.
  write_weights(&cache, "1234");
  void* span2_weights = cache_end(&cache);
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, span2_weights, 4));
  ASSERT_EQ(1, cache.cache.hits);
  ASSERT_EQ(1, cache.cache.misses);
  ASSERT_EQ(4, cache.cache.weights.size);

  // Simulate a cache miss: new bytes become a second entry right after the first.
  write_weights(&cache, "5678");
  void* span3_weights = cache_end(&cache);
  ASSERT_EQ(4, xnn_get_or_insert_weights_cache(&cache, span3_weights, 4));
  ASSERT_EQ(1, cache.cache.hits);
  ASSERT_EQ(2, cache.cache.misses);
  ASSERT_EQ(2, cache.cache.num_entries);
  ASSERT_EQ(8, cache.cache.weights.size);

  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
84*4bdc9457SAndroid Build Coastguard Worker
// Inserting as many distinct entries as the cache initially has buckets must
// enlarge the bucket table, and every entry must still be found afterwards.
TEST(WEIGHTS_CACHE, grow) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_weights_cache cache;
  EXPECT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  const size_t initial_buckets = cache.cache.num_buckets;

  // Pass 1: every key is new, so each lookup is a miss appended at the end.
  size_t offset = 0;
  for (size_t key = 0; key < initial_buckets; ++key) {
    const std::string bytes = std::to_string(key);
    write_weights(&cache, bytes);
    ASSERT_EQ(offset, xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), bytes.length()));
    offset += bytes.length();
  }

  ASSERT_EQ(0, cache.cache.hits);
  ASSERT_EQ(initial_buckets, cache.cache.num_entries);
  // The bucket table must have been enlarged.
  ASSERT_LT(initial_buckets, cache.cache.num_buckets);

  // Pass 2: the same keys again; each must resolve to its original offset.
  offset = 0;
  for (size_t key = 0; key < initial_buckets; ++key) {
    const std::string bytes = std::to_string(key);
    write_weights(&cache, bytes);
    ASSERT_EQ(offset, xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), bytes.length()));
    offset += bytes.length();
  }
  // Every lookup in the second pass is a hit.
  ASSERT_EQ(initial_buckets, cache.cache.hits);

  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
114*4bdc9457SAndroid Build Coastguard Worker
// A weights buffer of the default size can be allocated and released.
TEST(WEIGHTS_MEMORY, allocate_and_release) {
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, XNN_DEFAULT_WEIGHTS_BUFFER_SIZE));
  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
120*4bdc9457SAndroid Build Coastguard Worker
// A reservation that fits leaves the buffer in place. A reservation that does
// not fit grows the capacity while keeping the size and existing contents.
TEST(WEIGHTS_MEMORY, grow) {
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, 8));
  // Allocations are rounded up to the page size, so this may exceed 8.
  const size_t initial_capacity = buffer.capacity;

  const std::string payload = "1234";
  std::memcpy(buffer.start, payload.data(), payload.length());
  buffer.size += payload.length();
  ASSERT_EQ(buffer.size, 4);
  const uintptr_t initial_start = reinterpret_cast<uintptr_t>(buffer.start);

  // Room for 4 more bytes remains, so reserving them must not move the buffer.
  ASSERT_EQ(xnn_status_success, xnn_reserve_weights_memory(&buffer, 4));
  ASSERT_EQ(initial_start, reinterpret_cast<uintptr_t>(buffer.start));

  // Pretend bytes were copied until the buffer is completely full.
  buffer.size = initial_capacity;

  const size_t full_size = buffer.size;
  ASSERT_EQ(xnn_status_success, xnn_reserve_weights_memory(&buffer, 4));

  // Growing must raise the capacity...
  ASSERT_LT(initial_capacity, buffer.capacity);
  // ...leaving at least the 4 requested bytes free...
  ASSERT_GE(buffer.capacity, buffer.size + 4);
  // ...without touching the size.
  ASSERT_EQ(full_size, buffer.size);

  // Bytes written before growing must survive the reallocation.
  const char* const bytes = static_cast<const char*>(buffer.start);
  const std::string preserved(bytes, bytes + payload.length());
  ASSERT_EQ(payload, preserved);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
156*4bdc9457SAndroid Build Coastguard Worker
// Finalizing a buffer that holds no data leaves it with zero size and zero
// capacity.
TEST(WEIGHTS_CACHE, finalize_empty) {
  constexpr size_t kOneMiB = 1024 * 1024;
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, kOneMiB));

  ASSERT_EQ(0, buffer.size);
  ASSERT_EQ(kOneMiB, buffer.capacity);

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  ASSERT_EQ(0, buffer.size);
  ASSERT_EQ(0, buffer.capacity);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
171*4bdc9457SAndroid Build Coastguard Worker
// Finalizing a buffer that holds a few bytes keeps the stored size. On web the
// capacity is left untouched; elsewhere it may be trimmed but never grows.
TEST(WEIGHTS_CACHE, finalize) {
  xnn_weights_buffer b;
  const size_t initial_capacity = 1024 * 1024;  // 1MB.
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&b, initial_capacity));
  const size_t actual_capacity = b.capacity;

  const std::string junk = "1234";
  std::memcpy(b.start, junk.data(), junk.length());
  b.size += junk.length();
  ASSERT_EQ(4, b.size);

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&b));
#if XNN_PLATFORM_WEB
  // Web does not support partial unmapping, so the capacity is unchanged.
  ASSERT_EQ(actual_capacity, b.capacity);
#else
  // The trimmed capacity depends on the page size, so its exact value is
  // platform-dependent. This only checks that finalizing did not grow the
  // buffer; an unchanged capacity also passes.
  ASSERT_GE(actual_capacity, b.capacity);
#endif
  ASSERT_EQ(4, b.size);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&b));
}
195*4bdc9457SAndroid Build Coastguard Worker
// Finalizing an already-finalized buffer succeeds and leaves its capacity and
// size untouched.
TEST(WEIGHTS_CACHE, finalize_twice) {
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, XNN_DEFAULT_WEIGHTS_BUFFER_SIZE));

  const std::string payload = "1234";
  std::memcpy(buffer.start, payload.data(), payload.length());
  buffer.size += payload.length();

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  const size_t finalized_capacity = buffer.capacity;

  // The second finalization is not an error...
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  // ...and changes neither capacity nor size.
  ASSERT_EQ(finalized_capacity, buffer.capacity);
  ASSERT_EQ(4, buffer.size);

  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
214*4bdc9457SAndroid Build Coastguard Worker
// Finalizing succeeds and keeps the stored size even when the requested
// capacity (8 bytes) is below the page size on every platform.
TEST(WEIGHTS_CACHE, finalize_capacity_smaller_than_page_aligned_size) {
  xnn_weights_buffer buffer;
  ASSERT_EQ(xnn_status_success, xnn_allocate_weights_memory(&buffer, 8));

  const std::string payload = "1234";
  std::memcpy(buffer.start, payload.data(), payload.length());
  buffer.size += payload.length();

  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_memory(&buffer));
  ASSERT_EQ(4, buffer.size);
  ASSERT_EQ(xnn_status_success, xnn_release_weights_memory(&buffer));
}
227*4bdc9457SAndroid Build Coastguard Worker
// Many threads concurrently insert identical weights. Exactly one insertion is
// a miss, every other one is a hit, and every thread must resolve to the single
// entry at offset 0.
TEST(WEIGHTS_CACHE, write_many_cache_hits) {
#if XNN_PLATFORM_WEB && !defined(__EMSCRIPTEN_PTHREADS__)
  GTEST_SKIP();
#endif
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  EXPECT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  const std::string weights = "0123456789abcdefghij";
  const size_t weights_size = weights.size();
  constexpr size_t num_threads = 20;

  // One slot per thread, so no two workers ever write the same element. The
  // slots are read only after all threads have been joined.
  std::vector<size_t> offsets(num_threads, XNN_CACHE_NOT_FOUND);
  auto write = [&](size_t thread_index) {
    write_weights(&cache, weights);
    offsets[thread_index] = xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), weights_size);
  };

  std::vector<std::thread> threads;
  threads.reserve(num_threads);
  for (size_t i = 0; i < num_threads; i++) {
    threads.emplace_back(write, i);
  }
  for (size_t i = 0; i < num_threads; i++) {
    threads[i].join();
  }

  // Identical bytes everywhere: every thread must see the one entry at offset 0.
  for (size_t i = 0; i < num_threads; i++) {
    EXPECT_EQ(0, offsets[i]) << "thread " << i;
  }
  ASSERT_EQ(num_threads - 1, cache.cache.hits);
  ASSERT_EQ(1, cache.cache.num_entries);
  ASSERT_EQ(weights_size, cache.cache.weights.size);
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
257*4bdc9457SAndroid Build Coastguard Worker
// Many threads concurrently insert distinct weights (a different rotation of
// the same string per thread). Every insertion is a miss, and the resulting
// entries must tile the weights buffer without overlap.
TEST(WEIGHTS_CACHE, write_many_cache_misses) {
#if XNN_PLATFORM_WEB && !defined(__EMSCRIPTEN_PTHREADS__)
  GTEST_SKIP();
#endif
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  EXPECT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));
  const std::string weights = "0123456789abcdefghij";
  const size_t weights_size = weights.size();
  constexpr size_t num_threads = 20;
  // Each thread needs its own rotation amount in [0, weights_size) so that all
  // inserted byte strings differ.
  ASSERT_LE(num_threads, weights_size);

  // One slot per thread, so no two workers ever write the same element. The
  // slots are read only after all threads have been joined.
  std::vector<size_t> offsets(num_threads, XNN_CACHE_NOT_FOUND);
  auto write = [&](size_t i) {
    std::string rotated_weights = weights;
    std::rotate(rotated_weights.begin(), rotated_weights.begin() + i,
                rotated_weights.end());
    write_weights(&cache, rotated_weights);
    offsets[i] = xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), weights_size);
  };

  std::vector<std::thread> threads;
  threads.reserve(num_threads);
  for (size_t i = 0; i < num_threads; i++) {
    threads.emplace_back(write, i);
  }
  for (size_t i = 0; i < num_threads; i++) {
    threads[i].join();
  }

  // Each reserve/insert pair is serialized by the cache: the reservation takes
  // the cache mutex and xnn_get_or_insert_weights_cache releases it. Each miss
  // is therefore appended right after the previous entry, and the sorted
  // offsets must be exactly 0, weights_size, 2 * weights_size, ...
  std::sort(offsets.begin(), offsets.end());
  for (size_t i = 0; i < num_threads; i++) {
    EXPECT_EQ(i * weights_size, offsets[i]);
  }
  ASSERT_EQ(0, cache.cache.hits);
  ASSERT_EQ(num_threads, cache.cache.num_entries);
  ASSERT_EQ(weights_size * num_threads, cache.cache.weights.size);
  EXPECT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
291*4bdc9457SAndroid Build Coastguard Worker
// After a hard finalization the cache rejects further changes: a second
// finalization, any reservation, and any insertion all fail.
TEST(WEIGHTS_CACHE, operations_on_finalized_cache_hard) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_weights_cache cache;
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  const auto kind = xnn_weights_cache_finalization_kind_hard;
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_cache(&cache, kind));
  // A second finalization is reported as an error.
  ASSERT_NE(xnn_status_success, xnn_finalize_weights_cache(&cache, kind));
  // Not even a single byte can be reserved.
  ASSERT_EQ(nullptr, xnn_reserve_space_in_weights_cache(&cache, 1));

  // Insertion must fail too. It must also not time out by unlocking a mutex
  // that was never locked, because the reservation above did not succeed.
  ASSERT_EQ(XNN_CACHE_NOT_FOUND, xnn_get_or_insert_weights_cache(&cache, cache.cache.weights.start, 4));

  ASSERT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
309*4bdc9457SAndroid Build Coastguard Worker
// After a soft finalization the cache cannot be finalized again and cannot
// hand out more space than its existing capacity.
TEST(WEIGHTS_CACHE, operations_on_finalized_cache_soft) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_weights_cache cache;
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  const auto kind = xnn_weights_cache_finalization_kind_soft;
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_cache(&cache, kind));
  // A second finalization is reported as an error.
  ASSERT_NE(xnn_status_success, xnn_finalize_weights_cache(&cache, kind));
  // Asking for more than the whole capacity cannot be satisfied.
  const size_t too_many_bytes = cache.cache.weights.capacity + 1;
  ASSERT_EQ(nullptr, xnn_reserve_space_in_weights_cache(&cache, too_many_bytes));

  ASSERT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
323*4bdc9457SAndroid Build Coastguard Worker
// After a soft finalization, an insertion that is a cache hit still succeeds,
// provided the existing buffer has room for the temporary copy. A cache miss
// is rejected, and a reservation larger than the remaining space fails.
TEST(WEIGHTS_CACHE, insert_into_finalized_cache_soft) {
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  struct xnn_weights_cache cache;
  ASSERT_EQ(xnn_status_success, xnn_init_weights_cache(&cache));

  write_weights(&cache, "1234");
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, cache.cache.weights.start, 4));
  ASSERT_EQ(xnn_status_success, xnn_finalize_weights_cache(&cache, xnn_weights_cache_finalization_kind_soft));

  // Inserting into a finalized cache is okay as long as cache memory has space and it is a cache hit.
  ASSERT_LT(cache.cache.weights.size + 4, cache.cache.weights.capacity);
  write_weights(&cache, "1234");
  ASSERT_EQ(0, xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), 4));
  ASSERT_EQ(4, cache.cache.weights.size);

  // Sufficient space, but a cache miss: the soft-finalized cache does not add
  // a new entry. The lookup uses the bytes just written at the current end.
  write_weights(&cache, "4567");
  ASSERT_EQ(XNN_CACHE_NOT_FOUND, xnn_get_or_insert_weights_cache(&cache, cache_end(&cache), 4));

  // Not enough space in the finalized weights cache: the full capacity cannot
  // fit on top of the bytes already committed. No data needs to be copied, so
  // the size is passed directly instead of building a capacity-sized string.
  // Don't use write_weights here as it asserts xnn_reserve_space_in_weights_cache does not return nullptr.
  ASSERT_EQ(nullptr, xnn_reserve_space_in_weights_cache(&cache, cache.cache.weights.capacity));

  ASSERT_EQ(xnn_status_success, xnn_release_weights_cache(&cache));
}
351