xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/src/Main.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
6*89c4ff92SAndroid Build Coastguard Worker #include <map>
7*89c4ff92SAndroid Build Coastguard Worker #include <vector>
8*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
9*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include "CmdArgsParser.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnNetworkExecutor.hpp"
13*89c4ff92SAndroid Build Coastguard Worker #include "AudioCapture.hpp"
14*89c4ff92SAndroid Build Coastguard Worker #include "SpeechRecognitionPipeline.hpp"
15*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterMFCC.hpp"
16*89c4ff92SAndroid Build Coastguard Worker 
// Containers for inference output: one buffer of int8_t values, and a list of such buffers.
// main() hands an InferenceResults instance to the pipeline's Inference/PostProcessing calls.
using InferenceResult = std::vector<int8_t>;
using InferenceResults = std::vector<InferenceResult>;

// Command line switch names.
// NOTE(review): LABEL_PATH and HELP are not referenced by the code visible in this file.
const std::string AUDIO_FILE_PATH{"--audio-file-path"};
const std::string MODEL_FILE_PATH{"--model-file-path"};
const std::string LABEL_PATH{"--label-path"};
const std::string PREFERRED_BACKENDS{"--preferred-backends"};
const std::string HELP{"--help"};
25*89c4ff92SAndroid Build Coastguard Worker 
/*
 * Index -> symbol table handed to asr::CreatePipeline() in main().
 * Keys 0..25 map to "a".."z", 26 to an apostrophe, 27 to a space and 28 to "$".
 * NOTE(review): presumably the keys are the model's output class indices - confirm against the model.
 */
std::map<int, std::string> labels = []
{
    // The position of each character in this string is its key in the table.
    const std::string symbols = "abcdefghijklmnopqrstuvwxyz\' $";

    std::map<int, std::string> table;
    for (std::string::size_type pos = 0; pos < symbols.size(); ++pos)
    {
        table.emplace(static_cast<int>(pos), std::string(1, symbols[pos]));
    }
    return table;
}();
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker /*
60*89c4ff92SAndroid Build Coastguard Worker  * The accepted options for this Speech Recognition executable
61*89c4ff92SAndroid Build Coastguard Worker  */
62*89c4ff92SAndroid Build Coastguard Worker static std::map<std::string, std::string> CMD_OPTIONS =
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker     {AUDIO_FILE_PATH,    "[REQUIRED] Path to the Audio file to run speech recognition on"},
65*89c4ff92SAndroid Build Coastguard Worker     {MODEL_FILE_PATH,    "[REQUIRED] Path to the Speech Recognition model to use"},
66*89c4ff92SAndroid Build Coastguard Worker     {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
67*89c4ff92SAndroid Build Coastguard Worker                          " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
68*89c4ff92SAndroid Build Coastguard Worker                          " Defaults to CpuAcc,CpuRef"}
69*89c4ff92SAndroid Build Coastguard Worker };
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker /*
72*89c4ff92SAndroid Build Coastguard Worker  * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector
73*89c4ff92SAndroid Build Coastguard Worker  */
GetPreferredBackendList(const std::string & preferredBackends)74*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends;
77*89c4ff92SAndroid Build Coastguard Worker     std::stringstream ss(preferredBackends);
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     while (ss.good())
80*89c4ff92SAndroid Build Coastguard Worker     {
81*89c4ff92SAndroid Build Coastguard Worker         std::string backend;
82*89c4ff92SAndroid Build Coastguard Worker         std::getline(ss, backend, ',');
83*89c4ff92SAndroid Build Coastguard Worker         backends.emplace_back(backend);
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker     return backends;
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])88*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
89*89c4ff92SAndroid Build Coastguard Worker {
90*89c4ff92SAndroid Build Coastguard Worker     bool isFirstWindow = true;
91*89c4ff92SAndroid Build Coastguard Worker     std::string currentRContext = "";
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker     std::map<std::string, std::string> options;
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
96*89c4ff92SAndroid Build Coastguard Worker     if (result != 0)
97*89c4ff92SAndroid Build Coastguard Worker     {
98*89c4ff92SAndroid Build Coastguard Worker         return result;
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     // Create the network options
102*89c4ff92SAndroid Build Coastguard Worker     common::PipelineOptions pipelineOptions;
103*89c4ff92SAndroid Build Coastguard Worker     pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
104*89c4ff92SAndroid Build Coastguard Worker     pipelineOptions.m_ModelName = "Wav2Letter";
105*89c4ff92SAndroid Build Coastguard Worker     if (CheckOptionSpecified(options, PREFERRED_BACKENDS))
106*89c4ff92SAndroid Build Coastguard Worker     {
107*89c4ff92SAndroid Build Coastguard Worker         pipelineOptions.m_backends = GetPreferredBackendList((GetSpecifiedOption(options, PREFERRED_BACKENDS)));
108*89c4ff92SAndroid Build Coastguard Worker     }
109*89c4ff92SAndroid Build Coastguard Worker     else
110*89c4ff92SAndroid Build Coastguard Worker     {
111*89c4ff92SAndroid Build Coastguard Worker         pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
112*89c4ff92SAndroid Build Coastguard Worker     }
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     asr::IPipelinePtr asrPipeline = asr::CreatePipeline(pipelineOptions, labels);
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     audio::AudioCapture capture;
117*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(GetSpecifiedOption(options, AUDIO_FILE_PATH));
118*89c4ff92SAndroid Build Coastguard Worker     capture.InitSlidingWindow(audioData.data(), audioData.size(), asrPipeline->getInputSamplesSize(),
119*89c4ff92SAndroid Build Coastguard Worker                               asrPipeline->getSlidingWindowOffset());
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     while (capture.HasNext())
122*89c4ff92SAndroid Build Coastguard Worker     {
123*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> audioBlock = capture.Next();
124*89c4ff92SAndroid Build Coastguard Worker         InferenceResults results;
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> preprocessedData = asrPipeline->PreProcessing(audioBlock);
127*89c4ff92SAndroid Build Coastguard Worker         asrPipeline->Inference<int8_t>(preprocessedData, results);
128*89c4ff92SAndroid Build Coastguard Worker         asrPipeline->PostProcessing<int8_t>(results, isFirstWindow, !capture.HasNext(), currentRContext);
129*89c4ff92SAndroid Build Coastguard Worker     }
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     return 0;
132*89c4ff92SAndroid Build Coastguard Worker }