1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_ 18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_ 19 20 #include <string> 21 #include <vector> 22 23 #include "lang_id/common/embedding-feature-extractor.h" 24 #include "lang_id/common/fel/feature-extractor.h" 25 #include "lang_id/common/fel/task-context.h" 26 #include "lang_id/common/fel/workspace.h" 27 #include "lang_id/common/lite_base/attributes.h" 28 #include "absl/strings/string_view.h" 29 30 namespace libtextclassifier3 { 31 namespace mobile { 32 33 template <class EXTRACTOR, class OBJ, class... ARGS> 34 class EmbeddingFeatureInterface { 35 public: 36 // Constructs this EmbeddingFeatureInterface. 37 // 38 // |arg_prefix| is a string prefix for the TaskContext parameters, passed to 39 // |the underlying EmbeddingFeatureExtractor. EmbeddingFeatureInterface(absl::string_view arg_prefix)40 explicit EmbeddingFeatureInterface(absl::string_view arg_prefix) 41 : feature_extractor_(arg_prefix) {} 42 43 // Sets up feature extractors and flags for processing (inference). SetupForProcessing(TaskContext * context)44 SAFTM_MUST_USE_RESULT bool SetupForProcessing(TaskContext *context) { 45 return feature_extractor_.Setup(context); 46 } 47 48 // Initializes feature extractor resources for processing (inference) 49 // including requesting a workspace for caching extracted features. InitForProcessing(TaskContext * context)50 SAFTM_MUST_USE_RESULT bool InitForProcessing(TaskContext *context) { 51 if (!feature_extractor_.Init(context)) return false; 52 feature_extractor_.RequestWorkspaces(&workspace_registry_); 53 return true; 54 } 55 56 // Preprocesses *obj using the internal workspace registry. Preprocess(WorkspaceSet * workspace,OBJ * obj)57 void Preprocess(WorkspaceSet *workspace, OBJ *obj) const { 58 workspace->Reset(workspace_registry_); 59 feature_extractor_.Preprocess(workspace, obj); 60 } 61 62 // Extract features from |obj|. On return, FeatureVector features[i] 63 // contains the features for the embedding space #i. 64 // 65 // This function uses the precomputed info from |workspace|. Usage pattern: 66 // 67 // EmbeddingFeatureInterface<...> feature_interface; 68 // ... 69 // OBJ obj; 70 // WorkspaceSet workspace; 71 // feature_interface.Preprocess(&workspace, &obj); 72 // 73 // // For the same obj, but with different args: 74 // std::vector<FeatureVector> features; 75 // feature_interface.GetFeatures(obj, args, workspace, &features); 76 // 77 // This pattern is useful (more efficient) if you can pre-compute some info 78 // for the entire |obj|, which is reused by the feature extraction performed 79 // for different args. If that is not the case, you can use the simpler 80 // version GetFeaturesNoCaching below. GetFeatures(const OBJ & obj,ARGS...args,const WorkspaceSet & workspace,std::vector<FeatureVector> * features)81 void GetFeatures(const OBJ &obj, ARGS... args, const WorkspaceSet &workspace, 82 std::vector<FeatureVector> *features) const { 83 feature_extractor_.ExtractFeatures(workspace, obj, args..., features); 84 } 85 86 // Simpler version of GetFeatures(), for cases when there is no opportunity to 87 // reuse computation between feature extractions for the same |obj|, but with 88 // different |args|. Returns the extracted features. For more info, see the 89 // doc for GetFeatures(). GetFeaturesNoCaching(OBJ * obj,ARGS...args)90 std::vector<FeatureVector> GetFeaturesNoCaching(OBJ *obj, 91 ARGS... args) const { 92 // Technically, we still use a workspace, because 93 // feature_extractor_.ExtractFeatures requires one. But there is no real 94 // caching here, as we start from scratch for each call to ExtractFeatures. 95 WorkspaceSet workspace; 96 Preprocess(&workspace, obj); 97 std::vector<FeatureVector> features(NumEmbeddings()); 98 GetFeatures(*obj, args..., workspace, &features); 99 return features; 100 } 101 102 // Returns number of embedding spaces. NumEmbeddings()103 int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); } 104 105 private: 106 // Typed feature extractor for embeddings. 107 EmbeddingFeatureExtractor<EXTRACTOR, OBJ, ARGS...> feature_extractor_; 108 109 // The registry of shared workspaces in the feature extractor. 110 WorkspaceRegistry workspace_registry_; 111 }; 112 113 } // namespace mobile 114 } // namespace nlp_saft 115 116 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_INTERFACE_H_ 117