1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 9*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 10*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 11*4bdc9457SAndroid Build Coastguard Worker #include <cstring> 12*4bdc9457SAndroid Build Coastguard Worker #include <vector> 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h> 16*4bdc9457SAndroid Build Coastguard Worker 17*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-tester.h" 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack { 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker class RuntimeTester : public SubgraphTester { 22*4bdc9457SAndroid Build Coastguard Worker public: 23*4bdc9457SAndroid Build Coastguard Worker using SubgraphTester::SubgraphTester; 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker template<typename T> RunWithFusion()26*4bdc9457SAndroid Build Coastguard Worker inline std::vector<T> RunWithFusion() { 27*4bdc9457SAndroid Build Coastguard Worker Run(); 28*4bdc9457SAndroid Build Coastguard Worker std::vector<char>& tensor = this->external_tensors_.at(this->output_id_); 29*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float)); 30*4bdc9457SAndroid Build Coastguard Worker std::memcpy(output.data(), tensor.data(), tensor.size()); 31*4bdc9457SAndroid Build Coastguard Worker return output; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker 34*4bdc9457SAndroid Build Coastguard Worker template<typename T> RunWithoutFusion()35*4bdc9457SAndroid Build Coastguard Worker inline std::vector<T> RunWithoutFusion() { 36*4bdc9457SAndroid Build Coastguard Worker Run(XNN_FLAG_NO_OPERATOR_FUSION); 37*4bdc9457SAndroid Build Coastguard Worker std::vector<char>& tensor = this->external_tensors_.at(this->output_id_); 38*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float)); 39*4bdc9457SAndroid Build Coastguard Worker memcpy(output.data(), tensor.data(), tensor.size()); 40*4bdc9457SAndroid Build Coastguard Worker return output; 41*4bdc9457SAndroid Build Coastguard Worker } 42*4bdc9457SAndroid Build Coastguard Worker NumOperators()43*4bdc9457SAndroid Build Coastguard Worker size_t NumOperators() { 44*4bdc9457SAndroid Build Coastguard Worker size_t count = 0; 45*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < runtime_->num_ops; i++) { 46*4bdc9457SAndroid Build Coastguard Worker if (runtime_->opdata[i].operator_objects[0] != NULL) { 47*4bdc9457SAndroid Build Coastguard Worker count++; 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker } 50*4bdc9457SAndroid Build Coastguard Worker return count; 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker 53*4bdc9457SAndroid Build Coastguard Worker private: 54*4bdc9457SAndroid Build Coastguard Worker void Run(uint32_t flags = 0) { 55*4bdc9457SAndroid Build Coastguard Worker xnn_runtime_t runtime = nullptr; 56*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(this->subgraph_.get(), nullptr, nullptr, flags, &runtime)); 57*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, runtime); 58*4bdc9457SAndroid Build Coastguard Worker runtime_.reset(runtime); 59*4bdc9457SAndroid Build Coastguard Worker 60*4bdc9457SAndroid Build Coastguard Worker std::vector<xnn_external_value> externals; 61*4bdc9457SAndroid Build Coastguard Worker for (auto it = this->external_tensors_.begin(); it != this->external_tensors_.end(); ++it) { 62*4bdc9457SAndroid Build Coastguard Worker if (it->first == this->output_id_) { 63*4bdc9457SAndroid Build Coastguard Worker // Scramble output tensor. 64*4bdc9457SAndroid Build Coastguard Worker std::fill(it->second.begin(), it->second.end(), 0xA8); 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker externals.push_back(xnn_external_value{it->first, it->second.data()}); 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker 69*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, externals.size(), externals.data())); 70*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); 71*4bdc9457SAndroid Build Coastguard Worker }; 72*4bdc9457SAndroid Build Coastguard Worker 73*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{nullptr, xnn_delete_runtime}; 74*4bdc9457SAndroid Build Coastguard Worker }; 75*4bdc9457SAndroid Build Coastguard Worker 76*4bdc9457SAndroid Build Coastguard Worker } // namespace xnnpack 77