xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/lang-id_jni.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id_jni.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <jni.h>
20*993b0882SAndroid Build Coastguard Worker 
21*993b0882SAndroid Build Coastguard Worker #include <type_traits>
22*993b0882SAndroid Build Coastguard Worker #include <vector>
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id-wrapper.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/java/jni-helper.h"
27*993b0882SAndroid Build Coastguard Worker #include "lang_id/fb_model/lang-id-from-fb.h"
28*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id.h"
29*993b0882SAndroid Build Coastguard Worker 
30*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::JniHelper;
31*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::JStringToUtf8String;
32*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::ScopedLocalRef;
33*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::StatusOr;
34*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
35*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
36*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::LangId;
37*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::LangIdResult;
38*993b0882SAndroid Build Coastguard Worker 
39*993b0882SAndroid Build Coastguard Worker namespace {
40*993b0882SAndroid Build Coastguard Worker 
LangIdResultToJObjectArray(JNIEnv * env,const std::vector<std::pair<std::string,float>> & lang_id_predictions)41*993b0882SAndroid Build Coastguard Worker StatusOr<ScopedLocalRef<jobjectArray>> LangIdResultToJObjectArray(
42*993b0882SAndroid Build Coastguard Worker     JNIEnv* env,
43*993b0882SAndroid Build Coastguard Worker     const std::vector<std::pair<std::string, float>>& lang_id_predictions) {
44*993b0882SAndroid Build Coastguard Worker   TC3_ASSIGN_OR_RETURN(
45*993b0882SAndroid Build Coastguard Worker       const ScopedLocalRef<jclass> result_class,
46*993b0882SAndroid Build Coastguard Worker       JniHelper::FindClass(
47*993b0882SAndroid Build Coastguard Worker           env, TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR "$LanguageResult"));
48*993b0882SAndroid Build Coastguard Worker 
49*993b0882SAndroid Build Coastguard Worker   TC3_ASSIGN_OR_RETURN(const jmethodID result_class_constructor,
50*993b0882SAndroid Build Coastguard Worker                        JniHelper::GetMethodID(env, result_class.get(), "<init>",
51*993b0882SAndroid Build Coastguard Worker                                               "(Ljava/lang/String;F)V"));
52*993b0882SAndroid Build Coastguard Worker   TC3_ASSIGN_OR_RETURN(
53*993b0882SAndroid Build Coastguard Worker       ScopedLocalRef<jobjectArray> results,
54*993b0882SAndroid Build Coastguard Worker       JniHelper::NewObjectArray(env, lang_id_predictions.size(),
55*993b0882SAndroid Build Coastguard Worker                                 result_class.get(), nullptr));
56*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < lang_id_predictions.size(); i++) {
57*993b0882SAndroid Build Coastguard Worker     TC3_ASSIGN_OR_RETURN(
58*993b0882SAndroid Build Coastguard Worker         const ScopedLocalRef<jstring> predicted_language,
59*993b0882SAndroid Build Coastguard Worker         JniHelper::NewStringUTF(env, lang_id_predictions[i].first.c_str()));
60*993b0882SAndroid Build Coastguard Worker     TC3_ASSIGN_OR_RETURN(
61*993b0882SAndroid Build Coastguard Worker         const ScopedLocalRef<jobject> result,
62*993b0882SAndroid Build Coastguard Worker         JniHelper::NewObject(
63*993b0882SAndroid Build Coastguard Worker             env, result_class.get(), result_class_constructor,
64*993b0882SAndroid Build Coastguard Worker             predicted_language.get(),
65*993b0882SAndroid Build Coastguard Worker             static_cast<jfloat>(lang_id_predictions[i].second)));
66*993b0882SAndroid Build Coastguard Worker     JniHelper::SetObjectArrayElement(env, results.get(), i, result.get());
67*993b0882SAndroid Build Coastguard Worker   }
68*993b0882SAndroid Build Coastguard Worker   return results;
69*993b0882SAndroid Build Coastguard Worker }
70*993b0882SAndroid Build Coastguard Worker 
GetNoiseThreshold(const LangId & model)71*993b0882SAndroid Build Coastguard Worker float GetNoiseThreshold(const LangId& model) {
72*993b0882SAndroid Build Coastguard Worker   return model.GetFloatProperty("text_classifier_langid_noise_threshold", -1.0);
73*993b0882SAndroid Build Coastguard Worker }
74*993b0882SAndroid Build Coastguard Worker }  // namespace
75*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: builds a LangId model from the flatbuffer file behind `fd`.
//
// Returns the model pointer encoded as a jlong, or 0 when no model was
// produced or the model reports itself invalid. On success, ownership moves
// to the Java side; the pointer is freed by nativeClose.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
84*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: builds a LangId model from the flatbuffer file at `path`.
//
// Returns the model pointer encoded as a jlong, or 0 if `path` cannot be
// converted to UTF-8, no model was produced, or the model is invalid. On
// success, ownership moves to the Java side; the pointer is freed by
// nativeClose.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject clazz, jstring path) {
  TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
                         JStringToUtf8String(env, path));
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
94*993b0882SAndroid Build Coastguard Worker }
95*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: builds a LangId model from the `size` bytes starting at
// `offset` in the file behind `fd`.
//
// Returns the model pointer encoded as a jlong, or 0 if the region is
// negative, no model was produced, or the model is invalid. On success,
// ownership moves to the Java side; the pointer is freed by nativeClose.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  // A negative offset or size can never describe a file region. Reject it here
  // rather than forwarding it; if the loader takes unsigned sizes, such a
  // value would wrap to a huge number.
  if (offset < 0 || size < 0) {
    return reinterpret_cast<jlong>(nullptr);
  }
  std::unique_ptr<LangId> lang_id =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
105*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: detects the language(s) of `text` using the model at `ptr`.
//
// Returns a Java LanguageResult[] built from the model's predictions. Returns
// null if `ptr` is 0, if `text` cannot be converted to UTF-8, or if building
// the Java array fails.
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
  LangId* const lang_id_model = reinterpret_cast<LangId*>(ptr);
  if (lang_id_model == nullptr) {
    return nullptr;
  }

  TC3_ASSIGN_OR_RETURN_NULL(const std::string utf8_text,
                            JStringToUtf8String(env, text));

  const std::vector<std::pair<std::string, float>>& predictions =
      libtextclassifier3::langid::GetPredictions(lang_id_model, utf8_text);

  TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jobjectArray> java_results,
                            LangIdResultToJObjectArray(env, predictions));
  // Hand the local reference over to the JVM as this call's return value.
  return java_results.release();
}
124*993b0882SAndroid Build Coastguard Worker 
TC3_JNI_METHOD(void,TC3_LANG_ID_CLASS_NAME,nativeClose)125*993b0882SAndroid Build Coastguard Worker TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
126*993b0882SAndroid Build Coastguard Worker (JNIEnv* env, jobject thiz, jlong ptr) {
127*993b0882SAndroid Build Coastguard Worker   if (!ptr) {
128*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Trying to close null LangId.";
129*993b0882SAndroid Build Coastguard Worker     return;
130*993b0882SAndroid Build Coastguard Worker   }
131*993b0882SAndroid Build Coastguard Worker   LangId* model = reinterpret_cast<LangId*>(ptr);
132*993b0882SAndroid Build Coastguard Worker   delete model;
133*993b0882SAndroid Build Coastguard Worker }
134*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: returns the version of the already-loaded model at `ptr`,
// or -1 when `ptr` is 0.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject thiz, jlong ptr) {
  if (ptr == 0) {
    return -1;
  }
  return reinterpret_cast<LangId*>(ptr)->GetModelVersion();
}
143*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: loads the model behind `fd` only to read its version.
//
// Returns the model version, or -1 if no model was produced or it is invalid.
// The temporary model is destroyed when `lang_id` goes out of scope.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
152*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: returns the model's "text_classifier_langid_threshold"
// float property, with -1.0 as the default value. Also returns -1.0 when
// `ptr` is 0.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return -1.0;
  }
  const LangId& lang_id_model = *reinterpret_cast<LangId*>(ptr);
  return lang_id_model.GetFloatProperty("text_classifier_langid_threshold",
                                        -1.0);
}
161*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: returns the model's noise threshold (via GetNoiseThreshold),
// or -1.0 when `ptr` is 0.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return -1.0;
  }
  return GetNoiseThreshold(*reinterpret_cast<LangId*>(ptr));
}
170*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: returns the model's "min_text_size_in_bytes" property, or 0
// when `ptr` is 0.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return 0;
  }
  const LangId& lang_id_model = *reinterpret_cast<LangId*>(ptr);
  // The property is read through the float accessor; converting to jint
  // truncates any fractional part toward zero.
  return static_cast<jint>(
      lang_id_model.GetFloatProperty("min_text_size_in_bytes", 0));
}
179*993b0882SAndroid Build Coastguard Worker 
// JNI entry point: loads the model stored in the `size` bytes at `offset` of
// the file behind `fd`, only to read its version.
//
// Returns the model version, or -1 if the region is negative, no model was
// produced, or the model is invalid. The temporary model is destroyed when
// `lang_id` goes out of scope.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  // A negative offset or size can never describe a file region. Reject it here
  // rather than forwarding it; if the loader takes unsigned sizes, such a
  // value would wrap to a huge number.
  if (offset < 0 || size < 0) {
    return -1;
  }
  std::unique_ptr<LangId> lang_id =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
189