1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "CmdArgsParser.hpp"
#include "ArmnnNetworkExecutor.hpp"
#include "AudioCapture.hpp"
#include "SpeechRecognitionPipeline.hpp"
#include "Wav2LetterMFCC.hpp"
16*89c4ff92SAndroid Build Coastguard Worker
// One int8 inference output, and the list of such outputs that main()
// passes to the pipeline's Inference<int8_t>() / PostProcessing<int8_t>().
using InferenceResult  = std::vector<int8_t>;
using InferenceResults = std::vector<InferenceResult>;

// Command-line switches recognised by this executable.
// NOTE(review): LABEL_PATH and HELP are not listed in CMD_OPTIONS, and the labels
// are hard-coded in this file, so neither is referenced here — confirm whether
// CmdArgsParser or other code relies on them before removing.
const std::string AUDIO_FILE_PATH{"--audio-file-path"};
const std::string MODEL_FILE_PATH{"--model-file-path"};
const std::string LABEL_PATH{"--label-path"};
const std::string PREFERRED_BACKENDS{"--preferred-backends"};
const std::string HELP{"--help"};
25*89c4ff92SAndroid Build Coastguard Worker
/*
 * Label table handed to asr::CreatePipeline in main(): maps an integer class
 * index to the text it stands for.
 *   0-25 -> "a" .. "z"
 *   26   -> "'" (apostrophe)
 *   27   -> " " (space)
 *   28   -> "$"
 */
std::map<int, std::string> labels = []
{
    // Entry i of the table is the single character symbols[i].
    const std::string symbols = "abcdefghijklmnopqrstuvwxyz' $";

    std::map<int, std::string> table;
    for (std::size_t index = 0; index < symbols.size(); ++index)
    {
        table.emplace(static_cast<int>(index), std::string(1, symbols[index]));
    }
    return table;
}();
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker /*
60*89c4ff92SAndroid Build Coastguard Worker * The accepted options for this Speech Recognition executable
61*89c4ff92SAndroid Build Coastguard Worker */
62*89c4ff92SAndroid Build Coastguard Worker static std::map<std::string, std::string> CMD_OPTIONS =
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker {AUDIO_FILE_PATH, "[REQUIRED] Path to the Audio file to run speech recognition on"},
65*89c4ff92SAndroid Build Coastguard Worker {MODEL_FILE_PATH, "[REQUIRED] Path to the Speech Recognition model to use"},
66*89c4ff92SAndroid Build Coastguard Worker {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
67*89c4ff92SAndroid Build Coastguard Worker " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
68*89c4ff92SAndroid Build Coastguard Worker " Defaults to CpuAcc,CpuRef"}
69*89c4ff92SAndroid Build Coastguard Worker };
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker /*
72*89c4ff92SAndroid Build Coastguard Worker * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector
73*89c4ff92SAndroid Build Coastguard Worker */
GetPreferredBackendList(const std::string & preferredBackends)74*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends;
77*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss(preferredBackends);
78*89c4ff92SAndroid Build Coastguard Worker
79*89c4ff92SAndroid Build Coastguard Worker while (ss.good())
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker std::string backend;
82*89c4ff92SAndroid Build Coastguard Worker std::getline(ss, backend, ',');
83*89c4ff92SAndroid Build Coastguard Worker backends.emplace_back(backend);
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker return backends;
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker
main(int argc,char * argv[])88*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
89*89c4ff92SAndroid Build Coastguard Worker {
90*89c4ff92SAndroid Build Coastguard Worker bool isFirstWindow = true;
91*89c4ff92SAndroid Build Coastguard Worker std::string currentRContext = "";
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, std::string> options;
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
96*89c4ff92SAndroid Build Coastguard Worker if (result != 0)
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker return result;
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker // Create the network options
102*89c4ff92SAndroid Build Coastguard Worker common::PipelineOptions pipelineOptions;
103*89c4ff92SAndroid Build Coastguard Worker pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
104*89c4ff92SAndroid Build Coastguard Worker pipelineOptions.m_ModelName = "Wav2Letter";
105*89c4ff92SAndroid Build Coastguard Worker if (CheckOptionSpecified(options, PREFERRED_BACKENDS))
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker pipelineOptions.m_backends = GetPreferredBackendList((GetSpecifiedOption(options, PREFERRED_BACKENDS)));
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker else
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker asr::IPipelinePtr asrPipeline = asr::CreatePipeline(pipelineOptions, labels);
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker audio::AudioCapture capture;
117*89c4ff92SAndroid Build Coastguard Worker std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(GetSpecifiedOption(options, AUDIO_FILE_PATH));
118*89c4ff92SAndroid Build Coastguard Worker capture.InitSlidingWindow(audioData.data(), audioData.size(), asrPipeline->getInputSamplesSize(),
119*89c4ff92SAndroid Build Coastguard Worker asrPipeline->getSlidingWindowOffset());
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker while (capture.HasNext())
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker std::vector<float> audioBlock = capture.Next();
124*89c4ff92SAndroid Build Coastguard Worker InferenceResults results;
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> preprocessedData = asrPipeline->PreProcessing(audioBlock);
127*89c4ff92SAndroid Build Coastguard Worker asrPipeline->Inference<int8_t>(preprocessedData, results);
128*89c4ff92SAndroid Build Coastguard Worker asrPipeline->PostProcessing<int8_t>(results, isFirstWindow, !capture.HasNext(), currentRContext);
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker
131*89c4ff92SAndroid Build Coastguard Worker return 0;
132*89c4ff92SAndroid Build Coastguard Worker }