1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id_jni.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <jni.h>
20*993b0882SAndroid Build Coastguard Worker
21*993b0882SAndroid Build Coastguard Worker #include <type_traits>
22*993b0882SAndroid Build Coastguard Worker #include <vector>
23*993b0882SAndroid Build Coastguard Worker
24*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id-wrapper.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/java/jni-helper.h"
27*993b0882SAndroid Build Coastguard Worker #include "lang_id/fb_model/lang-id-from-fb.h"
28*993b0882SAndroid Build Coastguard Worker #include "lang_id/lang-id.h"
29*993b0882SAndroid Build Coastguard Worker
30*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::JniHelper;
31*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::JStringToUtf8String;
32*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::ScopedLocalRef;
33*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::StatusOr;
34*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
35*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
36*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::LangId;
37*993b0882SAndroid Build Coastguard Worker using libtextclassifier3::mobile::lang_id::LangIdResult;
38*993b0882SAndroid Build Coastguard Worker
39*993b0882SAndroid Build Coastguard Worker namespace {
40*993b0882SAndroid Build Coastguard Worker
LangIdResultToJObjectArray(JNIEnv * env,const std::vector<std::pair<std::string,float>> & lang_id_predictions)41*993b0882SAndroid Build Coastguard Worker StatusOr<ScopedLocalRef<jobjectArray>> LangIdResultToJObjectArray(
42*993b0882SAndroid Build Coastguard Worker JNIEnv* env,
43*993b0882SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, float>>& lang_id_predictions) {
44*993b0882SAndroid Build Coastguard Worker TC3_ASSIGN_OR_RETURN(
45*993b0882SAndroid Build Coastguard Worker const ScopedLocalRef<jclass> result_class,
46*993b0882SAndroid Build Coastguard Worker JniHelper::FindClass(
47*993b0882SAndroid Build Coastguard Worker env, TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR "$LanguageResult"));
48*993b0882SAndroid Build Coastguard Worker
49*993b0882SAndroid Build Coastguard Worker TC3_ASSIGN_OR_RETURN(const jmethodID result_class_constructor,
50*993b0882SAndroid Build Coastguard Worker JniHelper::GetMethodID(env, result_class.get(), "<init>",
51*993b0882SAndroid Build Coastguard Worker "(Ljava/lang/String;F)V"));
52*993b0882SAndroid Build Coastguard Worker TC3_ASSIGN_OR_RETURN(
53*993b0882SAndroid Build Coastguard Worker ScopedLocalRef<jobjectArray> results,
54*993b0882SAndroid Build Coastguard Worker JniHelper::NewObjectArray(env, lang_id_predictions.size(),
55*993b0882SAndroid Build Coastguard Worker result_class.get(), nullptr));
56*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < lang_id_predictions.size(); i++) {
57*993b0882SAndroid Build Coastguard Worker TC3_ASSIGN_OR_RETURN(
58*993b0882SAndroid Build Coastguard Worker const ScopedLocalRef<jstring> predicted_language,
59*993b0882SAndroid Build Coastguard Worker JniHelper::NewStringUTF(env, lang_id_predictions[i].first.c_str()));
60*993b0882SAndroid Build Coastguard Worker TC3_ASSIGN_OR_RETURN(
61*993b0882SAndroid Build Coastguard Worker const ScopedLocalRef<jobject> result,
62*993b0882SAndroid Build Coastguard Worker JniHelper::NewObject(
63*993b0882SAndroid Build Coastguard Worker env, result_class.get(), result_class_constructor,
64*993b0882SAndroid Build Coastguard Worker predicted_language.get(),
65*993b0882SAndroid Build Coastguard Worker static_cast<jfloat>(lang_id_predictions[i].second)));
66*993b0882SAndroid Build Coastguard Worker JniHelper::SetObjectArrayElement(env, results.get(), i, result.get());
67*993b0882SAndroid Build Coastguard Worker }
68*993b0882SAndroid Build Coastguard Worker return results;
69*993b0882SAndroid Build Coastguard Worker }
70*993b0882SAndroid Build Coastguard Worker
GetNoiseThreshold(const LangId & model)71*993b0882SAndroid Build Coastguard Worker float GetNoiseThreshold(const LangId& model) {
72*993b0882SAndroid Build Coastguard Worker return model.GetFloatProperty("text_classifier_langid_noise_threshold", -1.0);
73*993b0882SAndroid Build Coastguard Worker }
74*993b0882SAndroid Build Coastguard Worker } // namespace
75*993b0882SAndroid Build Coastguard Worker
// Loads a LangId model from the flatbuffer file behind `fd`.
//
// Returns an opaque handle (the LangId* reinterpreted as jlong) that the
// caller owns and later releases via nativeClose, or 0 if no model was
// produced or the model reports itself invalid.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // Check for a null factory result before dereferencing it. An invalid model
  // is freed by unique_ptr on return and reported as a null handle.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
84*993b0882SAndroid Build Coastguard Worker
// Loads a LangId model from the flatbuffer file at `path`.
//
// Returns an opaque handle (the LangId* reinterpreted as jlong) that the
// caller owns and later releases via nativeClose, or 0 if `path` cannot be
// converted to UTF-8, no model was produced, or the model is invalid.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject clazz, jstring path) {
  TC3_ASSIGN_OR_RETURN_0(const std::string path_str,
                         JStringToUtf8String(env, path));
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
  // Check for a null factory result before dereferencing it. An invalid model
  // is freed by unique_ptr on return and reported as a null handle.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
95*993b0882SAndroid Build Coastguard Worker
// Loads a LangId model from the `size` bytes starting at `offset` in the file
// behind `fd`.
//
// Returns an opaque handle (the LangId* reinterpreted as jlong) that the
// caller owns and later releases via nativeClose, or 0 if no model was
// produced or the model reports itself invalid.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  std::unique_ptr<LangId> lang_id =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  // Check for a null factory result before dereferencing it. An invalid model
  // is freed by unique_ptr on return and reported as a null handle.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
105*993b0882SAndroid Build Coastguard Worker
// Runs language detection on `text` using the model behind handle `ptr`.
//
// Returns a Java LanguageResult[] built from the model's predictions, or null
// when the handle is 0, `text` cannot be converted to UTF-8, or building the
// Java array fails.
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
  LangId* const lang_id_model = reinterpret_cast<LangId*>(ptr);
  if (lang_id_model == nullptr) {
    return nullptr;
  }

  TC3_ASSIGN_OR_RETURN_NULL(const std::string utf8_text,
                            JStringToUtf8String(env, text));

  const auto& predictions =
      libtextclassifier3::langid::GetPredictions(lang_id_model, utf8_text);

  TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jobjectArray> java_results,
                            LangIdResultToJObjectArray(env, predictions));
  // Hand the local reference over to the JVM as the return value.
  return java_results.release();
}
124*993b0882SAndroid Build Coastguard Worker
TC3_JNI_METHOD(void,TC3_LANG_ID_CLASS_NAME,nativeClose)125*993b0882SAndroid Build Coastguard Worker TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
126*993b0882SAndroid Build Coastguard Worker (JNIEnv* env, jobject thiz, jlong ptr) {
127*993b0882SAndroid Build Coastguard Worker if (!ptr) {
128*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Trying to close null LangId.";
129*993b0882SAndroid Build Coastguard Worker return;
130*993b0882SAndroid Build Coastguard Worker }
131*993b0882SAndroid Build Coastguard Worker LangId* model = reinterpret_cast<LangId*>(ptr);
132*993b0882SAndroid Build Coastguard Worker delete model;
133*993b0882SAndroid Build Coastguard Worker }
134*993b0882SAndroid Build Coastguard Worker
// Returns the version of the model behind handle `ptr`, or -1 for a 0 handle.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject thiz, jlong ptr) {
  if (ptr != 0) {
    LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
    return lang_id->GetModelVersion();
  }
  return -1;
}
143*993b0882SAndroid Build Coastguard Worker
// Loads a temporary LangId model from the file behind `fd` and returns its
// version, or -1 if no model was produced or the model is invalid. The model
// is owned by a unique_ptr and freed before returning.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
152*993b0882SAndroid Build Coastguard Worker
// Returns the "text_classifier_langid_threshold" float property of the model
// behind handle `ptr`. -1.0 is returned for a 0 handle and is also passed as
// the property lookup's default value.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr != 0) {
    LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
    return lang_id->GetFloatProperty("text_classifier_langid_threshold", -1.0);
  }
  return -1.0;
}
161*993b0882SAndroid Build Coastguard Worker
// Returns the noise threshold of the model behind handle `ptr` (as computed
// by GetNoiseThreshold), or -1.0 for a 0 handle.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr != 0) {
    const LangId& lang_id = *reinterpret_cast<LangId*>(ptr);
    return GetNoiseThreshold(lang_id);
  }
  return -1.0;
}
170*993b0882SAndroid Build Coastguard Worker
// Returns the "min_text_size_in_bytes" property of the model behind handle
// `ptr`, or 0 for a 0 handle. 0 is also passed as the property lookup's
// default value.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (!ptr) {
    return 0;
  }
  LangId* model = reinterpret_cast<LangId*>(ptr);
  // The property is read through GetFloatProperty. Convert explicitly
  // (truncating toward zero) instead of relying on an implicit
  // floating -> jint narrowing conversion at the return statement.
  return static_cast<jint>(
      model->GetFloatProperty("min_text_size_in_bytes", 0));
}
179*993b0882SAndroid Build Coastguard Worker
// Loads a temporary LangId model from the `size` bytes starting at `offset`
// in the file behind `fd` and returns its version. Returns -1 if no model was
// produced or the model is invalid. The model is owned by a unique_ptr and
// freed before returning.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  std::unique_ptr<LangId> lang_id =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  // Check for a null factory result before dereferencing it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
189