// xref: /aosp_15_r20/external/armnn/samples/KeywordSpotting/src/KeywordSpottingPipeline.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "KeywordSpottingPipeline.hpp"

#include <map>
#include <stdexcept>
#include <string>
#include <utility>

#include "ArmnnNetworkExecutor.hpp"
#include "DsCNNPreprocessor.hpp"

10 namespace kws
11 {
KWSPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,std::unique_ptr<Decoder> decoder,std::unique_ptr<DsCNNPreprocessor> preProcessor)12 KWSPipeline::KWSPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
13                          std::unique_ptr<Decoder> decoder,
14                          std::unique_ptr<DsCNNPreprocessor> preProcessor
15                          ) :
16         m_executor(std::move(executor)),
17         m_decoder(std::move(decoder)),
18         m_preProcessor(std::move(preProcessor)) {}
19 
20 
PreProcessing(std::vector<float> & audio)21 std::vector<int8_t> KWSPipeline::PreProcessing(std::vector<float>& audio)
22 {
23     return m_preProcessor->Invoke(audio.data(), audio.size(), m_executor->GetQuantizationOffset(),
24                                   m_executor->GetQuantizationScale());
25 }
26 
Inference(const std::vector<int8_t> & preprocessedData,common::InferenceResults<int8_t> & result)27 void KWSPipeline::Inference(const std::vector<int8_t>& preprocessedData,
28                             common::InferenceResults<int8_t>& result)
29 {
30     m_executor->Run(preprocessedData.data(), preprocessedData.size(), result);
31 }
32 
PostProcessing(common::InferenceResults<int8_t> & inferenceResults,std::map<int,std::string> & labels,const std::function<void (int,std::string &,float)> & callback)33 void KWSPipeline::PostProcessing(common::InferenceResults<int8_t>& inferenceResults,
34                     std::map<int, std::string>& labels,
35                     const std::function<void (int, std::string&, float)>& callback)
36 {
37     std::pair<int,float> outputDecoder = this->m_decoder->decodeOutput(inferenceResults[0]);
38     int keywordIndex = std::get<0>(outputDecoder);
39     std::string output = labels[keywordIndex];
40     callback(keywordIndex, output, std::get<1>(outputDecoder));
41 }
42 
getInputSamplesSize()43 int KWSPipeline::getInputSamplesSize()
44 {
45     return this->m_preProcessor->m_windowLen +
46             ((this->m_preProcessor->m_mfcc->m_params.m_numMfccVectors - 1) *
47               this->m_preProcessor->m_windowStride);
48 }
49 
CreatePipeline(common::PipelineOptions & config)50 IPipelinePtr CreatePipeline(common::PipelineOptions& config)
51 {
52     if (config.m_ModelName == "DS_CNN_CLUSTERED_INT8")
53     {
54         //DS-CNN model settings
55         float SAMP_FREQ = 16000;
56         int MFCC_WINDOW_LEN = 640;
57         int MFCC_WINDOW_STRIDE = 320;
58         int NUM_MFCC_FEATS = 10;
59         int NUM_MFCC_VECTORS = 49;
60         //todo: calc in pipeline and use in main
61         int SAMPLES_PER_INFERENCE = NUM_MFCC_VECTORS * MFCC_WINDOW_STRIDE +
62                                     MFCC_WINDOW_LEN - MFCC_WINDOW_STRIDE; //16000
63         float MEL_LO_FREQ = 20;
64         float MEL_HI_FREQ = 4000;
65         int NUM_FBANK_BIN = 40;
66 
67         MfccParams mfccParams(SAMP_FREQ,
68                               NUM_FBANK_BIN,
69                               MEL_LO_FREQ,
70                               MEL_HI_FREQ,
71                               NUM_MFCC_FEATS,
72                               MFCC_WINDOW_LEN, false,
73                               NUM_MFCC_VECTORS);
74 
75         std::unique_ptr<DsCnnMFCC> mfccInst = std::make_unique<DsCnnMFCC>(mfccParams);
76         auto preprocessor = std::make_unique<kws::DsCNNPreprocessor>(
77             MFCC_WINDOW_LEN, MFCC_WINDOW_STRIDE, std::move(mfccInst));
78 
79         auto executor = std::make_unique<common::ArmnnNetworkExecutor<int8_t>>(
80             config.m_ModelFilePath, config.m_backends);
81 
82         auto decoder = std::make_unique<kws::Decoder>(executor->GetOutputQuantizationOffset(0),
83                                                       executor->GetOutputQuantizationScale(0));
84 
85         return std::make_unique<kws::KWSPipeline>(std::move(executor),
86                                                   std::move(decoder), std::move(preprocessor));
87     }
88     else
89     {
90         throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " .");
91     }
92 }
93 
94 };// namespace kws