1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/lang-id_jni.h"
18
19 #include <jni.h>
20
21 #include <type_traits>
22 #include <vector>
23
24 #include "lang_id/lang-id-wrapper.h"
25 #include "utils/base/logging.h"
26 #include "utils/java/jni-helper.h"
27 #include "lang_id/fb_model/lang-id-from-fb.h"
28 #include "lang_id/lang-id.h"
29
30 using libtextclassifier3::JniHelper;
31 using libtextclassifier3::JStringToUtf8String;
32 using libtextclassifier3::ScopedLocalRef;
33 using libtextclassifier3::StatusOr;
34 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
35 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
36 using libtextclassifier3::mobile::lang_id::LangId;
37 using libtextclassifier3::mobile::lang_id::LangIdResult;
38
39 namespace {
40
LangIdResultToJObjectArray(JNIEnv * env,const std::vector<std::pair<std::string,float>> & lang_id_predictions)41 StatusOr<ScopedLocalRef<jobjectArray>> LangIdResultToJObjectArray(
42 JNIEnv* env,
43 const std::vector<std::pair<std::string, float>>& lang_id_predictions) {
44 TC3_ASSIGN_OR_RETURN(
45 const ScopedLocalRef<jclass> result_class,
46 JniHelper::FindClass(
47 env, TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR "$LanguageResult"));
48
49 TC3_ASSIGN_OR_RETURN(const jmethodID result_class_constructor,
50 JniHelper::GetMethodID(env, result_class.get(), "<init>",
51 "(Ljava/lang/String;F)V"));
52 TC3_ASSIGN_OR_RETURN(
53 ScopedLocalRef<jobjectArray> results,
54 JniHelper::NewObjectArray(env, lang_id_predictions.size(),
55 result_class.get(), nullptr));
56 for (int i = 0; i < lang_id_predictions.size(); i++) {
57 TC3_ASSIGN_OR_RETURN(
58 const ScopedLocalRef<jstring> predicted_language,
59 JniHelper::NewStringUTF(env, lang_id_predictions[i].first.c_str()));
60 TC3_ASSIGN_OR_RETURN(
61 const ScopedLocalRef<jobject> result,
62 JniHelper::NewObject(
63 env, result_class.get(), result_class_constructor,
64 predicted_language.get(),
65 static_cast<jfloat>(lang_id_predictions[i].second)));
66 JniHelper::SetObjectArrayElement(env, results.get(), i, result.get());
67 }
68 return results;
69 }
70
GetNoiseThreshold(const LangId & model)71 float GetNoiseThreshold(const LangId& model) {
72 return model.GetFloatProperty("text_classifier_langid_noise_threshold", -1.0);
73 }
74 } // namespace
75
// Builds a LangId through GetLangIdFromFlatbufferFileDescriptor(fd) and hands
// it to Java as an opaque jlong handle.
//
// Returns the handle when the model reports is_valid(); otherwise returns a
// null handle and the model is destroyed when the unique_ptr goes out of
// scope. A returned handle is owned by the caller (nativeClose deletes it).
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> model = GetLangIdFromFlatbufferFileDescriptor(fd);
  return model->is_valid() ? reinterpret_cast<jlong>(model.release())
                           : reinterpret_cast<jlong>(nullptr);
}
84
// Builds a LangId through GetLangIdFromFlatbufferFile() from the file path in
// the Java string `path` and hands it to Java as an opaque jlong handle.
//
// Returns the handle when the model reports is_valid(); otherwise returns a
// null handle and the model is destroyed when the unique_ptr goes out of
// scope. If the string conversion fails, TC3_ASSIGN_OR_RETURN_0 returns early
// (presumably with 0 -- see the macro definition).
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject clazz, jstring path) {
  TC3_ASSIGN_OR_RETURN_0(const std::string model_path,
                         JStringToUtf8String(env, path));
  std::unique_ptr<LangId> model = GetLangIdFromFlatbufferFile(model_path);
  return model->is_valid() ? reinterpret_cast<jlong>(model.release())
                           : reinterpret_cast<jlong>(nullptr);
}
95
// Builds a LangId through
// GetLangIdFromFlatbufferFileDescriptor(fd, offset, size) and hands it to Java
// as an opaque jlong handle.
//
// Returns the handle when the model reports is_valid(); otherwise returns a
// null handle and the model is destroyed when the unique_ptr goes out of
// scope.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  std::unique_ptr<LangId> model =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  return model->is_valid() ? reinterpret_cast<jlong>(model.release())
                           : reinterpret_cast<jlong>(nullptr);
}
105
// Runs language detection on `text` using the LangId behind the handle `ptr`
// and returns the predictions as a Java array built by
// LangIdResultToJObjectArray().
//
// Returns nullptr when the handle is null. The TC3_ASSIGN_OR_RETURN_NULL
// macros also return early if the string conversion or the array construction
// fails (presumably with nullptr -- see the macro definition).
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject thiz, jlong ptr, jstring text) {
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  if (lang_id == nullptr) {
    return nullptr;
  }

  TC3_ASSIGN_OR_RETURN_NULL(const std::string utf8_text,
                            JStringToUtf8String(env, text));

  const std::vector<std::pair<std::string, float>>& predictions =
      libtextclassifier3::langid::GetPredictions(lang_id, utf8_text);

  TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jobjectArray> java_predictions,
                            LangIdResultToJObjectArray(env, predictions));
  // Hand the local reference over to the JVM caller.
  return java_predictions.release();
}
124
TC3_JNI_METHOD(void,TC3_LANG_ID_CLASS_NAME,nativeClose)125 TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
126 (JNIEnv* env, jobject thiz, jlong ptr) {
127 if (!ptr) {
128 TC3_LOG(ERROR) << "Trying to close null LangId.";
129 return;
130 }
131 LangId* model = reinterpret_cast<LangId*>(ptr);
132 delete model;
133 }
134
// Returns GetModelVersion() of the LangId behind the handle `ptr`, or -1 when
// the handle is null.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject thiz, jlong ptr) {
  if (ptr == 0) {
    return -1;
  }
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  return lang_id->GetModelVersion();
}
143
// Builds a temporary LangId through GetLangIdFromFlatbufferFileDescriptor(fd)
// and returns its GetModelVersion(), or -1 when the model does not report
// is_valid(). The temporary model is destroyed before returning.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd) {
  const std::unique_ptr<LangId> model =
      GetLangIdFromFlatbufferFileDescriptor(fd);
  if (model->is_valid()) {
    return model->GetModelVersion();
  }
  return -1;
}
152
// Returns the "text_classifier_langid_threshold" float property of the LangId
// behind the handle `ptr`. Returns -1.0 when the handle is null; -1.0 is also
// passed as the default value to GetFloatProperty.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return -1.0;
  }
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  return lang_id->GetFloatProperty("text_classifier_langid_threshold", -1.0);
}
161
// Returns GetNoiseThreshold() for the LangId behind the handle `ptr`, or -1.0
// when the handle is null.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return -1.0;
  }
  return GetNoiseThreshold(*reinterpret_cast<LangId*>(ptr));
}
170
// Returns the "min_text_size_in_bytes" property of the LangId behind the
// handle `ptr`, or 0 when the handle is null. 0 is also passed as the default
// value to GetFloatProperty.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr == 0) {
    return 0;
  }
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  // The property is read through the float accessor; the conversion to jint
  // is spelled out here (any fractional part is dropped).
  return static_cast<jint>(
      lang_id->GetFloatProperty("min_text_size_in_bytes", 0));
}
179
// Builds a temporary LangId through
// GetLangIdFromFlatbufferFileDescriptor(fd, offset, size) and returns its
// GetModelVersion(), or -1 when the model does not report is_valid(). The
// temporary model is destroyed before returning.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
  const std::unique_ptr<LangId> model =
      GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
  if (model->is_valid()) {
    return model->GetModelVersion();
  }
  return -1;
}
189