xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/lang-id_jni.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/lang-id_jni.h"
18 
19 #include <jni.h>
20 
21 #include <type_traits>
22 #include <vector>
23 
24 #include "lang_id/lang-id-wrapper.h"
25 #include "utils/base/logging.h"
26 #include "utils/java/jni-helper.h"
27 #include "lang_id/fb_model/lang-id-from-fb.h"
28 #include "lang_id/lang-id.h"
29 
30 using libtextclassifier3::JniHelper;
31 using libtextclassifier3::JStringToUtf8String;
32 using libtextclassifier3::ScopedLocalRef;
33 using libtextclassifier3::StatusOr;
34 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
35 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
36 using libtextclassifier3::mobile::lang_id::LangId;
37 using libtextclassifier3::mobile::lang_id::LangIdResult;
38 
39 namespace {
40 
LangIdResultToJObjectArray(JNIEnv * env,const std::vector<std::pair<std::string,float>> & lang_id_predictions)41 StatusOr<ScopedLocalRef<jobjectArray>> LangIdResultToJObjectArray(
42     JNIEnv* env,
43     const std::vector<std::pair<std::string, float>>& lang_id_predictions) {
44   TC3_ASSIGN_OR_RETURN(
45       const ScopedLocalRef<jclass> result_class,
46       JniHelper::FindClass(
47           env, TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR "$LanguageResult"));
48 
49   TC3_ASSIGN_OR_RETURN(const jmethodID result_class_constructor,
50                        JniHelper::GetMethodID(env, result_class.get(), "<init>",
51                                               "(Ljava/lang/String;F)V"));
52   TC3_ASSIGN_OR_RETURN(
53       ScopedLocalRef<jobjectArray> results,
54       JniHelper::NewObjectArray(env, lang_id_predictions.size(),
55                                 result_class.get(), nullptr));
56   for (int i = 0; i < lang_id_predictions.size(); i++) {
57     TC3_ASSIGN_OR_RETURN(
58         const ScopedLocalRef<jstring> predicted_language,
59         JniHelper::NewStringUTF(env, lang_id_predictions[i].first.c_str()));
60     TC3_ASSIGN_OR_RETURN(
61         const ScopedLocalRef<jobject> result,
62         JniHelper::NewObject(
63             env, result_class.get(), result_class_constructor,
64             predicted_language.get(),
65             static_cast<jfloat>(lang_id_predictions[i].second)));
66     JniHelper::SetObjectArrayElement(env, results.get(), i, result.get());
67   }
68   return results;
69 }
70 
GetNoiseThreshold(const LangId & model)71 float GetNoiseThreshold(const LangId& model) {
72   return model.GetFloatProperty("text_classifier_langid_noise_threshold", -1.0);
73 }
74 }  // namespace
75 
// Loads a LangId model from file descriptor |fd| and returns an owning handle
// to it, or 0 (a null pointer) if loading produced no valid model.
// The returned handle is expected to be freed with nativeClose.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // The factory returns a nullable std::unique_ptr; check it before calling
  // is_valid() so a null result is reported as failure instead of crashing.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  // Ownership is transferred to the Java side as an opaque jlong handle.
  return reinterpret_cast<jlong>(lang_id.release());
}
84 
// Loads a LangId model from the file at |path| and returns an owning handle to
// it. Returns 0 if |path| cannot be converted to UTF-8 or if loading produced
// no valid model. The returned handle is expected to be freed with nativeClose.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject clazz, jstring path) {
  TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
                         JStringToUtf8String(env, path));
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
  // The factory returns a nullable std::unique_ptr; check it before calling
  // is_valid() so a null result is reported as failure instead of crashing.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  // Ownership is transferred to the Java side as an opaque jlong handle.
  return reinterpret_cast<jlong>(lang_id.release());
}
95 
// Loads a LangId model from the region [offset, offset + size) of file
// descriptor |fd| and returns an owning handle to it, or 0 if loading produced
// no valid model. The returned handle is expected to be freed with nativeClose.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  std::unique_ptr<LangId> lang_id =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  // The factory returns a nullable std::unique_ptr; check it before calling
  // is_valid() so a null result is reported as failure instead of crashing.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  // Ownership is transferred to the Java side as an opaque jlong handle.
  return reinterpret_cast<jlong>(lang_id.release());
}
105 
// Runs language detection on |text| with the model behind handle |ptr| and
// returns the predictions as a Java LanguageResult[].
// Returns nullptr when |ptr| is 0, when |text| cannot be converted to UTF-8, or
// when building the Java array fails.
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
  LangId* const lang_id_model = reinterpret_cast<LangId*>(ptr);
  if (lang_id_model == nullptr) {
    return nullptr;
  }

  TC3_ASSIGN_OR_RETURN_NULL(const std::string utf8_text,
                            JStringToUtf8String(env, text));

  // Bound by const reference, so this works whether GetPredictions returns by
  // value (lifetime-extended temporary) or by reference.
  const auto& predictions =
      libtextclassifier3::langid::GetPredictions(lang_id_model, utf8_text);

  TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jobjectArray> java_results,
                            LangIdResultToJObjectArray(env, predictions));
  // Hand the local reference to the JVM as the return value.
  return java_results.release();
}
124 
TC3_JNI_METHOD(void,TC3_LANG_ID_CLASS_NAME,nativeClose)125 TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
126 (JNIEnv* env, jobject thiz, jlong ptr) {
127   if (!ptr) {
128     TC3_LOG(ERROR) << "Trying to close null LangId.";
129     return;
130   }
131   LangId* model = reinterpret_cast<LangId*>(ptr);
132   delete model;
133 }
134 
// Returns the version of the model behind handle |ptr|, or -1 if |ptr| is 0.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject thiz, jlong ptr) {
  return ptr == 0 ? -1 : reinterpret_cast<LangId*>(ptr)->GetModelVersion();
}
143 
// Loads a model from file descriptor |fd| only to read its version. The
// temporary LangId is destroyed on return. Returns -1 if loading produced no
// valid model.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // The factory returns a nullable std::unique_ptr; check it before calling
  // is_valid() so a null result is reported as -1 instead of crashing.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
152 
// Returns the "text_classifier_langid_threshold" float property of the model
// behind handle |ptr|. -1.0 is returned for a 0 handle and is also the default
// value passed to GetFloatProperty.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  constexpr jfloat kUnsetThreshold = -1.0f;
  if (ptr == 0) {
    return kUnsetThreshold;
  }
  const LangId& lang_id_model = *reinterpret_cast<LangId*>(ptr);
  return lang_id_model.GetFloatProperty("text_classifier_langid_threshold",
                                        kUnsetThreshold);
}
161 
// Returns the noise threshold (via GetNoiseThreshold) of the model behind
// handle |ptr|, or -1.0 if |ptr| is 0.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return -1.0f;
  }
  const LangId& lang_id_model = *reinterpret_cast<LangId*>(ptr);
  return GetNoiseThreshold(lang_id_model);
}
170 
// Returns the model's "min_text_size_in_bytes" property (stored as a float)
// truncated to jint, or 0 if |ptr| is 0. 0 is also the default value passed to
// GetFloatProperty, and is returned when the stored value is NaN or outside
// the range of jint.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (!ptr) {
    return 0;
  }
  const LangId* model = reinterpret_cast<const LangId*>(ptr);
  const float min_text_size =
      model->GetFloatProperty("min_text_size_in_bytes", 0.0f);
  // Converting a floating-point value that is NaN or not representable in the
  // target integer type is undefined behavior, so validate before casting.
  // jint is a signed 32-bit type; 2^31 is exactly representable as a float.
  constexpr float kJintRangeLimit = 2147483648.0f;
  if (!(min_text_size >= -kJintRangeLimit && min_text_size < kJintRangeLimit)) {
    return 0;
  }
  return static_cast<jint>(min_text_size);
}
179 
// Loads a model from the region [offset, offset + size) of file descriptor
// |fd| only to read its version. The temporary LangId is destroyed on return.
// Returns -1 if loading produced no valid model.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  std::unique_ptr<LangId> lang_id =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  // The factory returns a nullable std::unique_ptr; check it before calling
  // is_valid() so a null result is reported as -1 instead of crashing.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
189