1*993b0882SAndroid Build Coastguard Worker /* 2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project 3*993b0882SAndroid Build Coastguard Worker * 4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*993b0882SAndroid Build Coastguard Worker * 8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*993b0882SAndroid Build Coastguard Worker * 10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*993b0882SAndroid Build Coastguard Worker * limitations under the License. 15*993b0882SAndroid Build Coastguard Worker */ 16*993b0882SAndroid Build Coastguard Worker 17*993b0882SAndroid Build Coastguard Worker #ifndef NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_ 18*993b0882SAndroid Build Coastguard Worker #define NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_ 19*993b0882SAndroid Build Coastguard Worker 20*993b0882SAndroid Build Coastguard Worker #include <string> 21*993b0882SAndroid Build Coastguard Worker #include <vector> 22*993b0882SAndroid Build Coastguard Worker 23*993b0882SAndroid Build Coastguard Worker #include "lang_id/common/embedding-network-params.h" 24*993b0882SAndroid Build Coastguard Worker 25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 26*993b0882SAndroid Build Coastguard Worker namespace mobile { 27*993b0882SAndroid Build Coastguard Worker namespace lang_id { 28*993b0882SAndroid Build Coastguard Worker 29*993b0882SAndroid Build Coastguard Worker // Interface for accessing parameters for the LangId model. 30*993b0882SAndroid Build Coastguard Worker // 31*993b0882SAndroid Build Coastguard Worker // Note: some clients prefer to include the model parameters in the binary, 32*993b0882SAndroid Build Coastguard Worker // others prefer loading them from a separate file. This file provides a common 33*993b0882SAndroid Build Coastguard Worker // interface for these alternative mechanisms. 34*993b0882SAndroid Build Coastguard Worker class ModelProvider { 35*993b0882SAndroid Build Coastguard Worker public: 36*993b0882SAndroid Build Coastguard Worker virtual ~ModelProvider() = default; 37*993b0882SAndroid Build Coastguard Worker 38*993b0882SAndroid Build Coastguard Worker // Returns true if this ModelProvider has been succesfully constructed (e.g., 39*993b0882SAndroid Build Coastguard Worker // can return false if an underlying model file could not be read). Clients 40*993b0882SAndroid Build Coastguard Worker // should not use invalid ModelProviders. is_valid()41*993b0882SAndroid Build Coastguard Worker bool is_valid() { return valid_; } 42*993b0882SAndroid Build Coastguard Worker 43*993b0882SAndroid Build Coastguard Worker // Returns the TaskContext with parameters for the LangId model. E.g., one 44*993b0882SAndroid Build Coastguard Worker // important parameter specifies the features to use. 45*993b0882SAndroid Build Coastguard Worker virtual const TaskContext *GetTaskContext() const = 0; 46*993b0882SAndroid Build Coastguard Worker 47*993b0882SAndroid Build Coastguard Worker // Returns parameters for the underlying Neurosis feed-forward neural network. 48*993b0882SAndroid Build Coastguard Worker virtual const EmbeddingNetworkParams *GetNnParams() const = 0; 49*993b0882SAndroid Build Coastguard Worker 50*993b0882SAndroid Build Coastguard Worker // Returns list of languages recognized by the model. Each element of the 51*993b0882SAndroid Build Coastguard Worker // returned vector should be a BCP-47 language code (e.g., "en", "ro", etc). 52*993b0882SAndroid Build Coastguard Worker // Language at index i from the returned vector corresponds to softmax label 53*993b0882SAndroid Build Coastguard Worker // i. 54*993b0882SAndroid Build Coastguard Worker virtual std::vector<std::string> GetLanguages() const = 0; 55*993b0882SAndroid Build Coastguard Worker 56*993b0882SAndroid Build Coastguard Worker protected: 57*993b0882SAndroid Build Coastguard Worker bool valid_ = false; 58*993b0882SAndroid Build Coastguard Worker }; 59*993b0882SAndroid Build Coastguard Worker 60*993b0882SAndroid Build Coastguard Worker } // namespace lang_id 61*993b0882SAndroid Build Coastguard Worker } // namespace mobile 62*993b0882SAndroid Build Coastguard Worker } // namespace nlp_saft 63*993b0882SAndroid Build Coastguard Worker 64*993b0882SAndroid Build Coastguard Worker #endif // NLP_SAFT_COMPONENTS_LANG_ID_MOBILE_MODEL_PROVIDER_H_ 65