xref: /aosp_15_r20/external/XNNPACK/test/prelu-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
11*4bdc9457SAndroid Build Coastguard Worker 
12*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
13*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
14*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
15*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
16*4bdc9457SAndroid Build Coastguard Worker #include <functional>
17*4bdc9457SAndroid Build Coastguard Worker #include <random>
18*4bdc9457SAndroid Build Coastguard Worker #include <vector>
19*4bdc9457SAndroid Build Coastguard Worker 
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h>
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker class PReLUOperatorTester {
25*4bdc9457SAndroid Build Coastguard Worker  public:
26*4bdc9457SAndroid Build Coastguard Worker   enum class WeightsType {
27*4bdc9457SAndroid Build Coastguard Worker     Default,
28*4bdc9457SAndroid Build Coastguard Worker     FP32,
29*4bdc9457SAndroid Build Coastguard Worker   };
30*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)31*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& batch_size(size_t batch_size) {
32*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
33*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
34*4bdc9457SAndroid Build Coastguard Worker     return *this;
35*4bdc9457SAndroid Build Coastguard Worker   }
36*4bdc9457SAndroid Build Coastguard Worker 
batch_size()37*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
38*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
39*4bdc9457SAndroid Build Coastguard Worker   }
40*4bdc9457SAndroid Build Coastguard Worker 
channels(size_t channels)41*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& channels(size_t channels) {
42*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
43*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
44*4bdc9457SAndroid Build Coastguard Worker     return *this;
45*4bdc9457SAndroid Build Coastguard Worker   }
46*4bdc9457SAndroid Build Coastguard Worker 
channels()47*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
48*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
49*4bdc9457SAndroid Build Coastguard Worker   }
50*4bdc9457SAndroid Build Coastguard Worker 
x_stride(size_t x_stride)51*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& x_stride(size_t x_stride) {
52*4bdc9457SAndroid Build Coastguard Worker     assert(x_stride != 0);
53*4bdc9457SAndroid Build Coastguard Worker     this->x_stride_ = x_stride;
54*4bdc9457SAndroid Build Coastguard Worker     return *this;
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
x_stride()57*4bdc9457SAndroid Build Coastguard Worker   inline size_t x_stride() const {
58*4bdc9457SAndroid Build Coastguard Worker     if (this->x_stride_ == 0) {
59*4bdc9457SAndroid Build Coastguard Worker       return this->channels_;
60*4bdc9457SAndroid Build Coastguard Worker     } else {
61*4bdc9457SAndroid Build Coastguard Worker       assert(this->x_stride_ >= this->channels_);
62*4bdc9457SAndroid Build Coastguard Worker       return this->x_stride_;
63*4bdc9457SAndroid Build Coastguard Worker     }
64*4bdc9457SAndroid Build Coastguard Worker   }
65*4bdc9457SAndroid Build Coastguard Worker 
y_stride(size_t y_stride)66*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& y_stride(size_t y_stride) {
67*4bdc9457SAndroid Build Coastguard Worker     assert(y_stride != 0);
68*4bdc9457SAndroid Build Coastguard Worker     this->y_stride_ = y_stride;
69*4bdc9457SAndroid Build Coastguard Worker     return *this;
70*4bdc9457SAndroid Build Coastguard Worker   }
71*4bdc9457SAndroid Build Coastguard Worker 
y_stride()72*4bdc9457SAndroid Build Coastguard Worker   inline size_t y_stride() const {
73*4bdc9457SAndroid Build Coastguard Worker     if (this->y_stride_ == 0) {
74*4bdc9457SAndroid Build Coastguard Worker       return this->channels_;
75*4bdc9457SAndroid Build Coastguard Worker     } else {
76*4bdc9457SAndroid Build Coastguard Worker       assert(this->y_stride_ >= this->channels_);
77*4bdc9457SAndroid Build Coastguard Worker       return this->y_stride_;
78*4bdc9457SAndroid Build Coastguard Worker     }
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
weights_type(WeightsType weights_type)81*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& weights_type(WeightsType weights_type) {
82*4bdc9457SAndroid Build Coastguard Worker     this->weights_type_ = weights_type;
83*4bdc9457SAndroid Build Coastguard Worker     return *this;
84*4bdc9457SAndroid Build Coastguard Worker   }
85*4bdc9457SAndroid Build Coastguard Worker 
weights_type()86*4bdc9457SAndroid Build Coastguard Worker   inline WeightsType weights_type() const {
87*4bdc9457SAndroid Build Coastguard Worker     return this->weights_type_;
88*4bdc9457SAndroid Build Coastguard Worker   }
89*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)90*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& iterations(size_t iterations) {
91*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
92*4bdc9457SAndroid Build Coastguard Worker     return *this;
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
iterations()95*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
96*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
97*4bdc9457SAndroid Build Coastguard Worker   }
98*4bdc9457SAndroid Build Coastguard Worker 
use_weights_cache(bool use_weights_cache)99*4bdc9457SAndroid Build Coastguard Worker   inline PReLUOperatorTester& use_weights_cache(bool use_weights_cache) {
100*4bdc9457SAndroid Build Coastguard Worker     this->use_weights_cache_ = use_weights_cache;
101*4bdc9457SAndroid Build Coastguard Worker     return *this;
102*4bdc9457SAndroid Build Coastguard Worker   }
103*4bdc9457SAndroid Build Coastguard Worker 
use_weights_cache()104*4bdc9457SAndroid Build Coastguard Worker   inline bool use_weights_cache() const {
105*4bdc9457SAndroid Build Coastguard Worker     return this->use_weights_cache_;
106*4bdc9457SAndroid Build Coastguard Worker   }
107*4bdc9457SAndroid Build Coastguard Worker 
TestF16()108*4bdc9457SAndroid Build Coastguard Worker   void TestF16() const {
109*4bdc9457SAndroid Build Coastguard Worker     switch (weights_type()) {
110*4bdc9457SAndroid Build Coastguard Worker       case WeightsType::Default:
111*4bdc9457SAndroid Build Coastguard Worker         break;
112*4bdc9457SAndroid Build Coastguard Worker       case WeightsType::FP32:
113*4bdc9457SAndroid Build Coastguard Worker         break;
114*4bdc9457SAndroid Build Coastguard Worker       default:
115*4bdc9457SAndroid Build Coastguard Worker         GTEST_FAIL() << "unexpected weights type";
116*4bdc9457SAndroid Build Coastguard Worker     }
117*4bdc9457SAndroid Build Coastguard Worker 
118*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
119*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
120*4bdc9457SAndroid Build Coastguard Worker     auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f);
121*4bdc9457SAndroid Build Coastguard Worker     auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f);
122*4bdc9457SAndroid Build Coastguard Worker 
123*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
124*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> w(channels());
125*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> w_as_float(channels());
126*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
127*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size() * channels());
128*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
129*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), [&] { return fp16_ieee_from_fp32_value(f32irng(rng)); });
130*4bdc9457SAndroid Build Coastguard Worker       std::generate(w.begin(), w.end(), [&] { return fp16_ieee_from_fp32_value(f32wrng(rng)); });
131*4bdc9457SAndroid Build Coastguard Worker       std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value);
132*4bdc9457SAndroid Build Coastguard Worker       std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
133*4bdc9457SAndroid Build Coastguard Worker 
134*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
135*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
136*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
137*4bdc9457SAndroid Build Coastguard Worker           const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
138*4bdc9457SAndroid Build Coastguard Worker           const float w_value = w_as_float[c];
139*4bdc9457SAndroid Build Coastguard Worker           y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value;
140*4bdc9457SAndroid Build Coastguard Worker         }
141*4bdc9457SAndroid Build Coastguard Worker       }
142*4bdc9457SAndroid Build Coastguard Worker 
143*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy PReLU operator.
144*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
145*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t prelu_op = nullptr;
146*4bdc9457SAndroid Build Coastguard Worker 
147*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
148*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
149*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
150*4bdc9457SAndroid Build Coastguard Worker       };
151*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
152*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
153*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
154*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
155*4bdc9457SAndroid Build Coastguard Worker       }
156*4bdc9457SAndroid Build Coastguard Worker 
157*4bdc9457SAndroid Build Coastguard Worker       const void* negative_slope_data = w.data();
158*4bdc9457SAndroid Build Coastguard Worker       if (weights_type() == WeightsType::FP32) {
159*4bdc9457SAndroid Build Coastguard Worker         negative_slope_data = w_as_float.data();
160*4bdc9457SAndroid Build Coastguard Worker       }
161*4bdc9457SAndroid Build Coastguard Worker       uint32_t flags = 0;
162*4bdc9457SAndroid Build Coastguard Worker       if (weights_type() == WeightsType::FP32) {
163*4bdc9457SAndroid Build Coastguard Worker         flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
164*4bdc9457SAndroid Build Coastguard Worker       }
165*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
166*4bdc9457SAndroid Build Coastguard Worker         xnn_create_prelu_nc_f16(
167*4bdc9457SAndroid Build Coastguard Worker           channels(), x_stride(), y_stride(),
168*4bdc9457SAndroid Build Coastguard Worker           negative_slope_data,
169*4bdc9457SAndroid Build Coastguard Worker           flags, &caches, &prelu_op));
170*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, prelu_op);
171*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
172*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
173*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
174*4bdc9457SAndroid Build Coastguard Worker       }
175*4bdc9457SAndroid Build Coastguard Worker 
176*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete prelu_op.
177*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
178*4bdc9457SAndroid Build Coastguard Worker 
179*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
180*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_prelu_nc_f16(
181*4bdc9457SAndroid Build Coastguard Worker           prelu_op,
182*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
183*4bdc9457SAndroid Build Coastguard Worker           x.data(), y.data(),
184*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
185*4bdc9457SAndroid Build Coastguard Worker 
186*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
187*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(prelu_op, nullptr /* thread pool */));
188*4bdc9457SAndroid Build Coastguard Worker 
189*4bdc9457SAndroid Build Coastguard Worker       VerifyF16(y, y_ref);
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
192*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t prelu_op2 = nullptr;
193*4bdc9457SAndroid Build Coastguard Worker         const size_t old_weights_cache_size = weights_cache.cache.weights.size;
194*4bdc9457SAndroid Build Coastguard Worker 
195*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
196*4bdc9457SAndroid Build Coastguard Worker                   xnn_create_prelu_nc_f16(
197*4bdc9457SAndroid Build Coastguard Worker                       channels(), x_stride(), y_stride(),
198*4bdc9457SAndroid Build Coastguard Worker                       negative_slope_data,
199*4bdc9457SAndroid Build Coastguard Worker                       flags, &caches, &prelu_op2));
200*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, prelu_op2);
201*4bdc9457SAndroid Build Coastguard Worker 
202*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete prelu_op2.
203*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op2, xnn_delete_operator);
204*4bdc9457SAndroid Build Coastguard Worker 
205*4bdc9457SAndroid Build Coastguard Worker         std::vector<uint16_t> y2(y.size(), UINT16_C(0x7E00) /* NaN */);
206*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
207*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_prelu_nc_f16(
208*4bdc9457SAndroid Build Coastguard Worker                       prelu_op2,
209*4bdc9457SAndroid Build Coastguard Worker                       batch_size(),
210*4bdc9457SAndroid Build Coastguard Worker                       x.data(), y2.data(),
211*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
212*4bdc9457SAndroid Build Coastguard Worker 
213*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
214*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(prelu_op2, nullptr /* thread pool */));
215*4bdc9457SAndroid Build Coastguard Worker 
216*4bdc9457SAndroid Build Coastguard Worker         VerifyF16(y2, y_ref);
217*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
218*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
219*4bdc9457SAndroid Build Coastguard Worker       }
220*4bdc9457SAndroid Build Coastguard Worker     }
221*4bdc9457SAndroid Build Coastguard Worker   }
222*4bdc9457SAndroid Build Coastguard Worker 
VerifyF16(const std::vector<uint16_t> & y,const std::vector<float> & y_ref)223*4bdc9457SAndroid Build Coastguard Worker   void VerifyF16(const std::vector<uint16_t>& y, const std::vector<float>& y_ref) const {
224*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
225*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
226*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
227*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
228*4bdc9457SAndroid Build Coastguard Worker             y_ref[i * channels() + c],
229*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-3f))
230*4bdc9457SAndroid Build Coastguard Worker             << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
231*4bdc9457SAndroid Build Coastguard Worker       }
232*4bdc9457SAndroid Build Coastguard Worker     }
233*4bdc9457SAndroid Build Coastguard Worker   }
234*4bdc9457SAndroid Build Coastguard Worker 
TestF32()235*4bdc9457SAndroid Build Coastguard Worker   void TestF32() const {
236*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
237*4bdc9457SAndroid Build Coastguard Worker 
238*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
239*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
240*4bdc9457SAndroid Build Coastguard Worker     auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f);
241*4bdc9457SAndroid Build Coastguard Worker     auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f);
242*4bdc9457SAndroid Build Coastguard Worker 
243*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
244*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> w(channels());
245*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
246*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size() * channels());
247*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
248*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), [&] { return f32irng(rng);} );
249*4bdc9457SAndroid Build Coastguard Worker       std::generate(w.begin(), w.end(), [&] { return f32wrng(rng);} );
250*4bdc9457SAndroid Build Coastguard Worker       std::fill(y.begin(), y.end(), nanf(""));
251*4bdc9457SAndroid Build Coastguard Worker 
252*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
253*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
254*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
255*4bdc9457SAndroid Build Coastguard Worker           y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c];
256*4bdc9457SAndroid Build Coastguard Worker         }
257*4bdc9457SAndroid Build Coastguard Worker       }
258*4bdc9457SAndroid Build Coastguard Worker 
259*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy PReLU operator.
260*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
261*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t prelu_op = nullptr;
262*4bdc9457SAndroid Build Coastguard Worker 
263*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
264*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
265*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
266*4bdc9457SAndroid Build Coastguard Worker       };
267*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
268*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
269*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
270*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
271*4bdc9457SAndroid Build Coastguard Worker       }
272*4bdc9457SAndroid Build Coastguard Worker 
273*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
274*4bdc9457SAndroid Build Coastguard Worker         xnn_create_prelu_nc_f32(
275*4bdc9457SAndroid Build Coastguard Worker           channels(), x_stride(), y_stride(),
276*4bdc9457SAndroid Build Coastguard Worker           w.data(),
277*4bdc9457SAndroid Build Coastguard Worker           0, &caches, &prelu_op));
278*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, prelu_op);
279*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
280*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
281*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
282*4bdc9457SAndroid Build Coastguard Worker       }
283*4bdc9457SAndroid Build Coastguard Worker 
284*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete prelu_op.
285*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
286*4bdc9457SAndroid Build Coastguard Worker 
287*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
288*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_prelu_nc_f32(
289*4bdc9457SAndroid Build Coastguard Worker           prelu_op,
290*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
291*4bdc9457SAndroid Build Coastguard Worker           x.data(), y.data(),
292*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
293*4bdc9457SAndroid Build Coastguard Worker 
294*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
295*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(prelu_op, nullptr /* thread pool */));
296*4bdc9457SAndroid Build Coastguard Worker 
297*4bdc9457SAndroid Build Coastguard Worker       VerifyF32(y, y_ref);
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
300*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t prelu_op2 = nullptr;
301*4bdc9457SAndroid Build Coastguard Worker         const size_t old_weights_cache_size = weights_cache.cache.weights.size;
302*4bdc9457SAndroid Build Coastguard Worker 
303*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
304*4bdc9457SAndroid Build Coastguard Worker                   xnn_create_prelu_nc_f32(
305*4bdc9457SAndroid Build Coastguard Worker                       channels(), x_stride(), y_stride(),
306*4bdc9457SAndroid Build Coastguard Worker                       w.data(),
307*4bdc9457SAndroid Build Coastguard Worker                       0, &caches, &prelu_op2));
308*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, prelu_op2);
309*4bdc9457SAndroid Build Coastguard Worker 
310*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete prelu_op2.
311*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op2, xnn_delete_operator);
312*4bdc9457SAndroid Build Coastguard Worker         std::vector<float> y2(y.size(), nanf(""));
313*4bdc9457SAndroid Build Coastguard Worker 
314*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
315*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_prelu_nc_f32(
316*4bdc9457SAndroid Build Coastguard Worker                       prelu_op2,
317*4bdc9457SAndroid Build Coastguard Worker                       batch_size(),
318*4bdc9457SAndroid Build Coastguard Worker                       x.data(), y2.data(),
319*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
320*4bdc9457SAndroid Build Coastguard Worker 
321*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
322*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(prelu_op2, nullptr /* thread pool */));
323*4bdc9457SAndroid Build Coastguard Worker 
324*4bdc9457SAndroid Build Coastguard Worker         VerifyF32(y, y_ref);
325*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
326*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
327*4bdc9457SAndroid Build Coastguard Worker       }
328*4bdc9457SAndroid Build Coastguard Worker     }
329*4bdc9457SAndroid Build Coastguard Worker   }
330*4bdc9457SAndroid Build Coastguard Worker 
VerifyF32(const std::vector<float> & y,const std::vector<float> & y_ref)331*4bdc9457SAndroid Build Coastguard Worker   void VerifyF32(const std::vector<float>& y, const std::vector<float>& y_ref) const {
332*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
333*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
334*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
335*4bdc9457SAndroid Build Coastguard Worker             y[i * y_stride() + c],
336*4bdc9457SAndroid Build Coastguard Worker             y_ref[i * channels() + c],
337*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
338*4bdc9457SAndroid Build Coastguard Worker           << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
339*4bdc9457SAndroid Build Coastguard Worker       }
340*4bdc9457SAndroid Build Coastguard Worker     }
341*4bdc9457SAndroid Build Coastguard Worker   }
342*4bdc9457SAndroid Build Coastguard Worker 
VerifyWeightsCache(const xnn_weights_cache & weights_cache,size_t old_size)343*4bdc9457SAndroid Build Coastguard Worker   void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const {
344*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_cache.cache.hits, 1);
345*4bdc9457SAndroid Build Coastguard Worker     // Ensure that we did not write more weights to the cache because it was a cache hit.
346*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(old_size, weights_cache.cache.weights.size);
347*4bdc9457SAndroid Build Coastguard Worker   };
348*4bdc9457SAndroid Build Coastguard Worker 
349*4bdc9457SAndroid Build Coastguard Worker  private:
350*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_{1};
351*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
352*4bdc9457SAndroid Build Coastguard Worker   size_t x_stride_{0};
353*4bdc9457SAndroid Build Coastguard Worker   size_t y_stride_{0};
354*4bdc9457SAndroid Build Coastguard Worker   WeightsType weights_type_{WeightsType::Default};
355*4bdc9457SAndroid Build Coastguard Worker   bool use_weights_cache_{false};
356*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{15};
357*4bdc9457SAndroid Build Coastguard Worker };
358