xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/include/SpeechRecognitionPipeline.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
#include "Wav2LetterPreprocessor.hpp"
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker namespace asr
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker /**
16*89c4ff92SAndroid Build Coastguard Worker  * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference
17*89c4ff92SAndroid Build Coastguard Worker  * result post-processing.
18*89c4ff92SAndroid Build Coastguard Worker  *
19*89c4ff92SAndroid Build Coastguard Worker  */
20*89c4ff92SAndroid Build Coastguard Worker class ASRPipeline
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker public:
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     /**
25*89c4ff92SAndroid Build Coastguard Worker      * Creates speech recognition pipeline with given network executor and decoder.
26*89c4ff92SAndroid Build Coastguard Worker      * @param executor - unique pointer to inference runner
27*89c4ff92SAndroid Build Coastguard Worker      * @param decoder - unique pointer to inference results decoder
28*89c4ff92SAndroid Build Coastguard Worker      */
29*89c4ff92SAndroid Build Coastguard Worker     ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
30*89c4ff92SAndroid Build Coastguard Worker                 std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor);
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker     /**
33*89c4ff92SAndroid Build Coastguard Worker      * @brief Standard audio pre-processing implementation.
34*89c4ff92SAndroid Build Coastguard Worker      *
35*89c4ff92SAndroid Build Coastguard Worker      * Preprocesses and prepares the data for inference by
36*89c4ff92SAndroid Build Coastguard Worker      * extracting the MFCC features.
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker      * @param[in] audio - the raw audio data
39*89c4ff92SAndroid Build Coastguard Worker      * @param[out] preprocessor - the preprocessor object, which handles the data preparation
40*89c4ff92SAndroid Build Coastguard Worker      */
41*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> PreProcessing(std::vector<float>& audio);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     int getInputSamplesSize();
44*89c4ff92SAndroid Build Coastguard Worker     int getSlidingWindowOffset();
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself
47*89c4ff92SAndroid Build Coastguard Worker     // Will need to be refactored so that hard coded values are not defined outside of model settings
48*89c4ff92SAndroid Build Coastguard Worker     int SLIDING_WINDOW_OFFSET;
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     /**
51*89c4ff92SAndroid Build Coastguard Worker      * @brief Executes inference
52*89c4ff92SAndroid Build Coastguard Worker      *
53*89c4ff92SAndroid Build Coastguard Worker      * Calls inference runner provided during instance construction.
54*89c4ff92SAndroid Build Coastguard Worker      *
55*89c4ff92SAndroid Build Coastguard Worker      * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor.
56*89c4ff92SAndroid Build Coastguard Worker      * @param[out] result - raw inference results.
57*89c4ff92SAndroid Build Coastguard Worker      */
58*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
Inference(const std::vector<T> & preprocessedData,common::InferenceResults<int8_t> & result)59*89c4ff92SAndroid Build Coastguard Worker     void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
60*89c4ff92SAndroid Build Coastguard Worker     {
61*89c4ff92SAndroid Build Coastguard Worker         size_t data_bytes = sizeof(T) * preprocessedData.size();
62*89c4ff92SAndroid Build Coastguard Worker         m_executor->Run(preprocessedData.data(), data_bytes, result);
63*89c4ff92SAndroid Build Coastguard Worker     }
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     /**
66*89c4ff92SAndroid Build Coastguard Worker      * @brief Standard inference results post-processing implementation.
67*89c4ff92SAndroid Build Coastguard Worker      *
68*89c4ff92SAndroid Build Coastguard Worker      * Decodes inference results using decoder provided during construction.
69*89c4ff92SAndroid Build Coastguard Worker      *
70*89c4ff92SAndroid Build Coastguard Worker      * @param[in] inferenceResult - inference results to be decoded.
71*89c4ff92SAndroid Build Coastguard Worker      * @param[in] isFirstWindow - for checking if this is the first window of the sliding window.
72*89c4ff92SAndroid Build Coastguard Worker      * @param[in] isLastWindow - for checking if this is the last window of the sliding window.
73*89c4ff92SAndroid Build Coastguard Worker      * @param[in] currentRContext - the right context of the output text. To be output if it is the last window.
74*89c4ff92SAndroid Build Coastguard Worker      */
75*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
PostProcessing(common::InferenceResults<int8_t> & inferenceResult,bool & isFirstWindow,bool isLastWindow,std::string currentRContext)76*89c4ff92SAndroid Build Coastguard Worker     void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
77*89c4ff92SAndroid Build Coastguard Worker                         bool& isFirstWindow,
78*89c4ff92SAndroid Build Coastguard Worker                         bool isLastWindow,
79*89c4ff92SAndroid Build Coastguard Worker                         std::string currentRContext)
80*89c4ff92SAndroid Build Coastguard Worker     {
81*89c4ff92SAndroid Build Coastguard Worker         int rowLength = 29;
82*89c4ff92SAndroid Build Coastguard Worker         int middleContextStart = 49;
83*89c4ff92SAndroid Build Coastguard Worker         int middleContextEnd = 99;
84*89c4ff92SAndroid Build Coastguard Worker         int leftContextStart = 0;
85*89c4ff92SAndroid Build Coastguard Worker         int rightContextStart = 100;
86*89c4ff92SAndroid Build Coastguard Worker         int rightContextEnd = 148;
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker         std::vector<T> contextToProcess;
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker         // If isFirstWindow we keep the left context of the output
91*89c4ff92SAndroid Build Coastguard Worker         if (isFirstWindow)
92*89c4ff92SAndroid Build Coastguard Worker         {
93*89c4ff92SAndroid Build Coastguard Worker             std::vector<T> chunk(&inferenceResult[0][leftContextStart],
94*89c4ff92SAndroid Build Coastguard Worker                                  &inferenceResult[0][middleContextEnd * rowLength]);
95*89c4ff92SAndroid Build Coastguard Worker             contextToProcess = chunk;
96*89c4ff92SAndroid Build Coastguard Worker         }
97*89c4ff92SAndroid Build Coastguard Worker         else
98*89c4ff92SAndroid Build Coastguard Worker         {
99*89c4ff92SAndroid Build Coastguard Worker             // Else we only keep the middle context of the output
100*89c4ff92SAndroid Build Coastguard Worker             std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
101*89c4ff92SAndroid Build Coastguard Worker                                  &inferenceResult[0][middleContextEnd * rowLength]);
102*89c4ff92SAndroid Build Coastguard Worker             contextToProcess = chunk;
103*89c4ff92SAndroid Build Coastguard Worker         }
104*89c4ff92SAndroid Build Coastguard Worker         std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
105*89c4ff92SAndroid Build Coastguard Worker         isFirstWindow = false;
106*89c4ff92SAndroid Build Coastguard Worker         std::cout << output << std::flush;
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker         // If this is the last window, we print the right context of the output
109*89c4ff92SAndroid Build Coastguard Worker         if (isLastWindow)
110*89c4ff92SAndroid Build Coastguard Worker         {
111*89c4ff92SAndroid Build Coastguard Worker             std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
112*89c4ff92SAndroid Build Coastguard Worker                                     &inferenceResult[0][rightContextEnd * rowLength]);
113*89c4ff92SAndroid Build Coastguard Worker             currentRContext = this->m_decoder->DecodeOutput(rContext);
114*89c4ff92SAndroid Build Coastguard Worker             std::cout << currentRContext << std::endl;
115*89c4ff92SAndroid Build Coastguard Worker         }
116*89c4ff92SAndroid Build Coastguard Worker     }
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker protected:
119*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
120*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<Decoder> m_decoder;
121*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor;
122*89c4ff92SAndroid Build Coastguard Worker };
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker /**
127*89c4ff92SAndroid Build Coastguard Worker  * Constructs speech recognition pipeline based on configuration provided.
128*89c4ff92SAndroid Build Coastguard Worker  *
129*89c4ff92SAndroid Build Coastguard Worker  * @param[in] config - speech recognition pipeline configuration.
130*89c4ff92SAndroid Build Coastguard Worker  * @param[in] labels - asr labels
131*89c4ff92SAndroid Build Coastguard Worker  *
132*89c4ff92SAndroid Build Coastguard Worker  * @return unique pointer to asr pipeline.
133*89c4ff92SAndroid Build Coastguard Worker  */
134*89c4ff92SAndroid Build Coastguard Worker IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker } // namespace asr