1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnNetworkExecutor.hpp" 9*89c4ff92SAndroid Build Coastguard Worker #include "Decoder.hpp" 10*89c4ff92SAndroid Build Coastguard Worker #include "MFCC.hpp" 11*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterPreprocessor.hpp" 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker namespace asr 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker /** 16*89c4ff92SAndroid Build Coastguard Worker * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference 17*89c4ff92SAndroid Build Coastguard Worker * result post-processing. 18*89c4ff92SAndroid Build Coastguard Worker * 19*89c4ff92SAndroid Build Coastguard Worker */ 20*89c4ff92SAndroid Build Coastguard Worker class ASRPipeline 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker public: 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker /** 25*89c4ff92SAndroid Build Coastguard Worker * Creates speech recognition pipeline with given network executor and decoder. 26*89c4ff92SAndroid Build Coastguard Worker * @param executor - unique pointer to inference runner 27*89c4ff92SAndroid Build Coastguard Worker * @param decoder - unique pointer to inference results decoder 28*89c4ff92SAndroid Build Coastguard Worker */ 29*89c4ff92SAndroid Build Coastguard Worker ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor, 30*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor); 31*89c4ff92SAndroid Build Coastguard Worker 32*89c4ff92SAndroid Build Coastguard Worker /** 33*89c4ff92SAndroid Build Coastguard Worker * @brief Standard audio pre-processing implementation. 34*89c4ff92SAndroid Build Coastguard Worker * 35*89c4ff92SAndroid Build Coastguard Worker * Preprocesses and prepares the data for inference by 36*89c4ff92SAndroid Build Coastguard Worker * extracting the MFCC features. 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker * @param[in] audio - the raw audio data 39*89c4ff92SAndroid Build Coastguard Worker * @param[out] preprocessor - the preprocessor object, which handles the data preparation 40*89c4ff92SAndroid Build Coastguard Worker */ 41*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> PreProcessing(std::vector<float>& audio); 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker int getInputSamplesSize(); 44*89c4ff92SAndroid Build Coastguard Worker int getSlidingWindowOffset(); 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself 47*89c4ff92SAndroid Build Coastguard Worker // Will need to be refactored so that hard coded values are not defined outside of model settings 48*89c4ff92SAndroid Build Coastguard Worker int SLIDING_WINDOW_OFFSET; 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker /** 51*89c4ff92SAndroid Build Coastguard Worker * @brief Executes inference 52*89c4ff92SAndroid Build Coastguard Worker * 53*89c4ff92SAndroid Build Coastguard Worker * Calls inference runner provided during instance construction. 54*89c4ff92SAndroid Build Coastguard Worker * 55*89c4ff92SAndroid Build Coastguard Worker * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor. 56*89c4ff92SAndroid Build Coastguard Worker * @param[out] result - raw inference results. 57*89c4ff92SAndroid Build Coastguard Worker */ 58*89c4ff92SAndroid Build Coastguard Worker template<typename T> Inference(const std::vector<T> & preprocessedData,common::InferenceResults<int8_t> & result)59*89c4ff92SAndroid Build Coastguard Worker void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result) 60*89c4ff92SAndroid Build Coastguard Worker { 61*89c4ff92SAndroid Build Coastguard Worker size_t data_bytes = sizeof(T) * preprocessedData.size(); 62*89c4ff92SAndroid Build Coastguard Worker m_executor->Run(preprocessedData.data(), data_bytes, result); 63*89c4ff92SAndroid Build Coastguard Worker } 64*89c4ff92SAndroid Build Coastguard Worker 65*89c4ff92SAndroid Build Coastguard Worker /** 66*89c4ff92SAndroid Build Coastguard Worker * @brief Standard inference results post-processing implementation. 67*89c4ff92SAndroid Build Coastguard Worker * 68*89c4ff92SAndroid Build Coastguard Worker * Decodes inference results using decoder provided during construction. 69*89c4ff92SAndroid Build Coastguard Worker * 70*89c4ff92SAndroid Build Coastguard Worker * @param[in] inferenceResult - inference results to be decoded. 71*89c4ff92SAndroid Build Coastguard Worker * @param[in] isFirstWindow - for checking if this is the first window of the sliding window. 72*89c4ff92SAndroid Build Coastguard Worker * @param[in] isLastWindow - for checking if this is the last window of the sliding window. 73*89c4ff92SAndroid Build Coastguard Worker * @param[in] currentRContext - the right context of the output text. To be output if it is the last window. 74*89c4ff92SAndroid Build Coastguard Worker */ 75*89c4ff92SAndroid Build Coastguard Worker template<typename T> PostProcessing(common::InferenceResults<int8_t> & inferenceResult,bool & isFirstWindow,bool isLastWindow,std::string currentRContext)76*89c4ff92SAndroid Build Coastguard Worker void PostProcessing(common::InferenceResults<int8_t>& inferenceResult, 77*89c4ff92SAndroid Build Coastguard Worker bool& isFirstWindow, 78*89c4ff92SAndroid Build Coastguard Worker bool isLastWindow, 79*89c4ff92SAndroid Build Coastguard Worker std::string currentRContext) 80*89c4ff92SAndroid Build Coastguard Worker { 81*89c4ff92SAndroid Build Coastguard Worker int rowLength = 29; 82*89c4ff92SAndroid Build Coastguard Worker int middleContextStart = 49; 83*89c4ff92SAndroid Build Coastguard Worker int middleContextEnd = 99; 84*89c4ff92SAndroid Build Coastguard Worker int leftContextStart = 0; 85*89c4ff92SAndroid Build Coastguard Worker int rightContextStart = 100; 86*89c4ff92SAndroid Build Coastguard Worker int rightContextEnd = 148; 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker std::vector<T> contextToProcess; 89*89c4ff92SAndroid Build Coastguard Worker 90*89c4ff92SAndroid Build Coastguard Worker // If isFirstWindow we keep the left context of the output 91*89c4ff92SAndroid Build Coastguard Worker if (isFirstWindow) 92*89c4ff92SAndroid Build Coastguard Worker { 93*89c4ff92SAndroid Build Coastguard Worker std::vector<T> chunk(&inferenceResult[0][leftContextStart], 94*89c4ff92SAndroid Build Coastguard Worker &inferenceResult[0][middleContextEnd * rowLength]); 95*89c4ff92SAndroid Build Coastguard Worker contextToProcess = chunk; 96*89c4ff92SAndroid Build Coastguard Worker } 97*89c4ff92SAndroid Build Coastguard Worker else 98*89c4ff92SAndroid Build Coastguard Worker { 99*89c4ff92SAndroid Build Coastguard Worker // Else we only keep the middle context of the output 100*89c4ff92SAndroid Build Coastguard Worker std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength], 101*89c4ff92SAndroid Build Coastguard Worker &inferenceResult[0][middleContextEnd * rowLength]); 102*89c4ff92SAndroid Build Coastguard Worker contextToProcess = chunk; 103*89c4ff92SAndroid Build Coastguard Worker } 104*89c4ff92SAndroid Build Coastguard Worker std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess); 105*89c4ff92SAndroid Build Coastguard Worker isFirstWindow = false; 106*89c4ff92SAndroid Build Coastguard Worker std::cout << output << std::flush; 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker // If this is the last window, we print the right context of the output 109*89c4ff92SAndroid Build Coastguard Worker if (isLastWindow) 110*89c4ff92SAndroid Build Coastguard Worker { 111*89c4ff92SAndroid Build Coastguard Worker std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength], 112*89c4ff92SAndroid Build Coastguard Worker &inferenceResult[0][rightContextEnd * rowLength]); 113*89c4ff92SAndroid Build Coastguard Worker currentRContext = this->m_decoder->DecodeOutput(rContext); 114*89c4ff92SAndroid Build Coastguard Worker std::cout << currentRContext << std::endl; 115*89c4ff92SAndroid Build Coastguard Worker } 116*89c4ff92SAndroid Build Coastguard Worker } 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker protected: 119*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor; 120*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Decoder> m_decoder; 121*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor; 122*89c4ff92SAndroid Build Coastguard Worker }; 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>; 125*89c4ff92SAndroid Build Coastguard Worker 126*89c4ff92SAndroid Build Coastguard Worker /** 127*89c4ff92SAndroid Build Coastguard Worker * Constructs speech recognition pipeline based on configuration provided. 128*89c4ff92SAndroid Build Coastguard Worker * 129*89c4ff92SAndroid Build Coastguard Worker * @param[in] config - speech recognition pipeline configuration. 130*89c4ff92SAndroid Build Coastguard Worker * @param[in] labels - asr labels 131*89c4ff92SAndroid Build Coastguard Worker * 132*89c4ff92SAndroid Build Coastguard Worker * @return unique pointer to asr pipeline. 133*89c4ff92SAndroid Build Coastguard Worker */ 134*89c4ff92SAndroid Build Coastguard Worker IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels); 135*89c4ff92SAndroid Build Coastguard Worker 136*89c4ff92SAndroid Build Coastguard Worker } // namespace asr