xref: /aosp_15_r20/external/armnn/samples/KeywordSpotting/src/Main.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "KeywordSpottingPipeline.hpp"
#include "CmdArgsParser.hpp"
#include "ArmnnNetworkExecutor.hpp"
#include "AudioCapture.hpp"
14 
// Command line option names understood by this sample.
const std::string AUDIO_FILE_PATH = "--audio-file-path";
const std::string MODEL_FILE_PATH = "--model-file-path";
// NOTE(review): LABEL_PATH is not registered in CMD_OPTIONS below; the model labels are
// hard-coded in this file instead.
const std::string LABEL_PATH = "--label-path";
const std::string PREFERRED_BACKENDS = "--preferred-backends";
const std::string HELP = "--help";

/*
 * The accepted options for this Keyword Spotting executable, each paired with its description.
 */
static std::map<std::string, std::string> CMD_OPTIONS =
{
        {AUDIO_FILE_PATH,    "[REQUIRED] Path to the Audio file to run keyword spotting on"},
        {MODEL_FILE_PATH,    "[REQUIRED] Path to the Keyword Spotting model to use"},
        {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
                             " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
                             " Defaults to CpuAcc,CpuRef"}
};
32 
33 /*
34  * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector
35  */
GetPreferredBackendList(const std::string & preferredBackends)36 std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
37 {
38     std::vector<armnn::BackendId> backends;
39     std::stringstream ss(preferredBackends);
40 
41     while (ss.good())
42     {
43         std::string backend;
44         std::getline(ss, backend, ',');
45         backends.emplace_back(backend);
46     }
47     return backends;
48 }
49 
// Maps each output index of the keyword spotting model to its label.
// Index 0 is "silence", index 1 is "unknown", and indices 2-11 are the keywords.
std::map<int, std::string> labels{
    {0, "silence"},
    {1, "unknown"},
    {2, "yes"},
    {3, "no"},
    {4, "up"},
    {5, "down"},
    {6, "left"},
    {7, "right"},
    {8, "on"},
    {9, "off"},
    {10, "stop"},
    {11, "go"},
};
65 };
66 
67 
main(int argc,char * argv[])68 int main(int argc, char* argv[])
69 {
70     printf("ArmNN major version: %d\n", ARMNN_MAJOR_VERSION);
71     std::map<std::string, std::string> options;
72 
73     //Read command line args
74     int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
75     if (result != 0)
76     {
77         return result;
78     }
79 
80     // Create the ArmNN inference runner
81     common::PipelineOptions pipelineOptions;
82     pipelineOptions.m_ModelName = "DS_CNN_CLUSTERED_INT8";
83     pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
84     if (CheckOptionSpecified(options, PREFERRED_BACKENDS))
85     {
86         pipelineOptions.m_backends = GetPreferredBackendList(
87             (GetSpecifiedOption(options, PREFERRED_BACKENDS)));
88     }
89     else
90     {
91         pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
92     }
93 
94     kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(pipelineOptions);
95 
96     //Extract audio data from sound file
97     auto filePath = GetSpecifiedOption(options, AUDIO_FILE_PATH);
98     std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(filePath);
99 
100     audio::AudioCapture capture;
101     //todo: read samples and stride from pipeline
102     capture.InitSlidingWindow(audioData.data(),
103                               audioData.size(),
104                               kwsPipeline->getInputSamplesSize(),
105                               kwsPipeline->getInputSamplesSize()/2);
106 
107     //Loop through audio data buffer
108     while (capture.HasNext())
109     {
110         std::vector<float> audioBlock = capture.Next();
111         common::InferenceResults<int8_t> results;
112 
113         //Prepare input tensors
114         std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(audioBlock);
115         //Run inference
116         kwsPipeline->Inference(preprocessedData, results);
117         //Decode output
118         kwsPipeline->PostProcessing(results, labels,
119                                     [](int index, std::string& label, float prob) -> void {
120                                         printf("Keyword \"%s\", index %d:, probability %f\n",
121                                                label.c_str(),
122                                                index,
123                                                prob);
124                                     });
125     }
126 
127     return 0;
128 }