xref: /aosp_15_r20/external/XNNPACK/test/runtime-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
9*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
10*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
11*4bdc9457SAndroid Build Coastguard Worker #include <cstring>
12*4bdc9457SAndroid Build Coastguard Worker #include <vector>
13*4bdc9457SAndroid Build Coastguard Worker 
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
16*4bdc9457SAndroid Build Coastguard Worker 
17*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-tester.h"
18*4bdc9457SAndroid Build Coastguard Worker 
19*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack {
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker class RuntimeTester : public SubgraphTester {
22*4bdc9457SAndroid Build Coastguard Worker  public:
23*4bdc9457SAndroid Build Coastguard Worker   using SubgraphTester::SubgraphTester;
24*4bdc9457SAndroid Build Coastguard Worker 
25*4bdc9457SAndroid Build Coastguard Worker   template<typename T>
RunWithFusion()26*4bdc9457SAndroid Build Coastguard Worker   inline std::vector<T> RunWithFusion() {
27*4bdc9457SAndroid Build Coastguard Worker     Run();
28*4bdc9457SAndroid Build Coastguard Worker     std::vector<char>& tensor = this->external_tensors_.at(this->output_id_);
29*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float));
30*4bdc9457SAndroid Build Coastguard Worker     std::memcpy(output.data(), tensor.data(), tensor.size());
31*4bdc9457SAndroid Build Coastguard Worker     return output;
32*4bdc9457SAndroid Build Coastguard Worker   }
33*4bdc9457SAndroid Build Coastguard Worker 
34*4bdc9457SAndroid Build Coastguard Worker   template<typename T>
RunWithoutFusion()35*4bdc9457SAndroid Build Coastguard Worker   inline std::vector<T> RunWithoutFusion() {
36*4bdc9457SAndroid Build Coastguard Worker     Run(XNN_FLAG_NO_OPERATOR_FUSION);
37*4bdc9457SAndroid Build Coastguard Worker     std::vector<char>& tensor = this->external_tensors_.at(this->output_id_);
38*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float));
39*4bdc9457SAndroid Build Coastguard Worker     memcpy(output.data(), tensor.data(), tensor.size());
40*4bdc9457SAndroid Build Coastguard Worker     return output;
41*4bdc9457SAndroid Build Coastguard Worker   }
42*4bdc9457SAndroid Build Coastguard Worker 
NumOperators()43*4bdc9457SAndroid Build Coastguard Worker   size_t NumOperators() {
44*4bdc9457SAndroid Build Coastguard Worker     size_t count = 0;
45*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < runtime_->num_ops; i++) {
46*4bdc9457SAndroid Build Coastguard Worker       if (runtime_->opdata[i].operator_objects[0] != NULL) {
47*4bdc9457SAndroid Build Coastguard Worker         count++;
48*4bdc9457SAndroid Build Coastguard Worker       }
49*4bdc9457SAndroid Build Coastguard Worker     }
50*4bdc9457SAndroid Build Coastguard Worker     return count;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
53*4bdc9457SAndroid Build Coastguard Worker  private:
54*4bdc9457SAndroid Build Coastguard Worker   void Run(uint32_t flags = 0) {
55*4bdc9457SAndroid Build Coastguard Worker     xnn_runtime_t runtime = nullptr;
56*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(this->subgraph_.get(), nullptr, nullptr, flags, &runtime));
57*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(nullptr, runtime);
58*4bdc9457SAndroid Build Coastguard Worker     runtime_.reset(runtime);
59*4bdc9457SAndroid Build Coastguard Worker 
60*4bdc9457SAndroid Build Coastguard Worker     std::vector<xnn_external_value> externals;
61*4bdc9457SAndroid Build Coastguard Worker     for (auto it = this->external_tensors_.begin(); it != this->external_tensors_.end(); ++it) {
62*4bdc9457SAndroid Build Coastguard Worker       if (it->first == this->output_id_) {
63*4bdc9457SAndroid Build Coastguard Worker         // Scramble output tensor.
64*4bdc9457SAndroid Build Coastguard Worker         std::fill(it->second.begin(), it->second.end(), 0xA8);
65*4bdc9457SAndroid Build Coastguard Worker       }
66*4bdc9457SAndroid Build Coastguard Worker       externals.push_back(xnn_external_value{it->first, it->second.data()});
67*4bdc9457SAndroid Build Coastguard Worker     }
68*4bdc9457SAndroid Build Coastguard Worker 
69*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, externals.size(), externals.data()));
70*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
71*4bdc9457SAndroid Build Coastguard Worker   };
72*4bdc9457SAndroid Build Coastguard Worker 
73*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{nullptr, xnn_delete_runtime};
74*4bdc9457SAndroid Build Coastguard Worker };
75*4bdc9457SAndroid Build Coastguard Worker 
76*4bdc9457SAndroid Build Coastguard Worker }  // namespace xnnpack
77