1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
17
18 #include <fcntl.h>
19
20 #include "android-base/file.h"
21 #include "tensorflow_lite_support/cc/port/gmock.h"
22 #include "tensorflow_lite_support/cc/port/gtest.h"
23 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
24
25 namespace tflite {
26 namespace task {
27 namespace text {
28 namespace nlclassifier {
29
30 namespace {
31
32 using ::android::base::GetExecutableDirectory;
33 using ::testing::HasSubstr;
34 using ::tflite::support::kTfLiteSupportPayload;
35 using ::tflite::support::StatusOr;
36 using ::tflite::support::TfLiteSupportStatus;
37 using ::tflite::task::core::Category;
38 using ::tflite::task::core::LoadBinaryContent;
39
// Location of the BERT NL classifier test model. The path starts with '/'
// because the tests append it to the directory containing the test
// executable (see the absl::StrCat(GetExecutableDirectory(), ...) calls).
constexpr char kTestModelPath[] =
    "/tensorflow_lite_support/cc/test/testdata/"
    "task/text/test_model_nl_classifier_bert.tflite";

// A model path that should not exist, used to exercise creation failures.
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";

// Input sequence length limit of BertNLClassifier; the long-input test repeats
// a word this many times to build an input longer than the limit.
constexpr int kMaxSeqLen{128};
47
// Creating a BertNLClassifier from the on-disk test model path must succeed.
TEST(BertNLClassifierTest, TestNLClassifierCreationFilePath) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromFile(test_model_path);
  // Stream the status so a failure reports *why* creation failed.
  EXPECT_TRUE(classifier.ok()) << classifier.status();
}
54
// Creating a BertNLClassifier from an in-memory copy of the test model must
// succeed.
TEST(BertNLClassifierTest, TestNLClassifierCreationBinary) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromBuffer(model_buffer.data(),
                                         model_buffer.size());
  // Stream the status so a failure reports *why* creation failed.
  EXPECT_TRUE(classifier.ok()) << classifier.status();
}
62
// Creating a classifier from a nonexistent path must fail with kNotFound, an
// error message naming the path, and the kFileNotFoundError support payload.
TEST(BertNLClassifierTest, TestNLClassifierCreationFailure) {
  const StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromFile(kInvalidModelPath);
  const auto& status = classifier.status();

  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(status.message(),
              HasSubstr("Unable to open file at i/do/not/exist.tflite"));
  EXPECT_THAT(status.GetPayload(kTfLiteSupportPayload),
              ::testing::Optional(absl::Cord(
                  absl::StrCat(TfLiteSupportStatus::kFileNotFoundError))));
}
74
GetCategoryWithClassName(const std::string & class_name,std::vector<Category> & categories)75 Category* GetCategoryWithClassName(const std::string& class_name,
76 std::vector<Category>& categories) {
77 for (Category& category : categories) {
78 if (category.class_name == class_name) {
79 return &category;
80 }
81 }
82 return nullptr;
83 }
84
verify_classifier(std::unique_ptr<BertNLClassifier> classifier,bool verify_positive)85 void verify_classifier(std::unique_ptr<BertNLClassifier> classifier,
86 bool verify_positive) {
87 if (verify_positive) {
88 tflite::support::StatusOr<std::vector<core::Category>> results =
89 classifier->ClassifyText("unflinchingly bleak and desperate");
90
91 EXPECT_TRUE(results.ok());
92 EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score,
93 GetCategoryWithClassName("positive", results.value())->score);
94 } else {
95 tflite::support::StatusOr<std::vector<core::Category>> results =
96 classifier->ClassifyText("it's a charming and often affecting journey");
97
98 EXPECT_TRUE(results.ok());
99 EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
100 GetCategoryWithClassName("negative", results.value())->score);
101 }
102 }
103
// A classifier built from an in-memory model buffer must score "negative"
// above "positive" for a clearly negative review.
TEST(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  // Kept in scope until the classifier is destroyed, in case CreateFromBuffer
  // does not copy the bytes (not visible from here).
  std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromBuffer(model_buffer.data(),
                                         model_buffer.size());
  // ASSERT rather than EXPECT: *classifier below is invalid on error.
  ASSERT_TRUE(classifier.ok()) << classifier.status();

  // Passing true makes verify_classifier check the negative review
  // ("unflinchingly bleak and desperate"), matching this test's name.
  verify_classifier(std::move(*classifier), true);
}
113
// A classifier built from an in-memory model buffer must score "positive"
// above "negative" for a clearly positive review.
TEST(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  // Kept in scope until the classifier is destroyed, in case CreateFromBuffer
  // does not copy the bytes (not visible from here).
  std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromBuffer(model_buffer.data(),
                                         model_buffer.size());
  // ASSERT rather than EXPECT: *classifier below is invalid on error.
  ASSERT_TRUE(classifier.ok()) << classifier.status();

  // Passing false makes verify_classifier check the positive review
  // ("it's a charming and often affecting journey"), matching this test's name.
  verify_classifier(std::move(*classifier), false);
}
123
// A classifier built from a file descriptor of the test model must score
// "positive" above "negative" for a clearly positive review.
TEST(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  const int model_fd = open(test_model_path.c_str(), O_RDONLY);
  // open() returns -1 on failure; report that directly instead of passing an
  // invalid descriptor to CreateFromFd.
  ASSERT_GE(model_fd, 0) << "Unable to open " << test_model_path;
  // NOTE(review): whether CreateFromFd takes ownership of (and closes)
  // model_fd is not visible here; if it does not, this descriptor leaks.
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromFd(model_fd);
  // ASSERT rather than EXPECT: *classifier below is invalid on error.
  ASSERT_TRUE(classifier.ok()) << classifier.status();

  // Passing false makes verify_classifier check the positive review.
  verify_classifier(std::move(*classifier), false);
}
132
// A classifier built from a file descriptor of the test model must score
// "negative" above "positive" for a clearly negative review.
TEST(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  const int model_fd = open(test_model_path.c_str(), O_RDONLY);
  // open() returns -1 on failure; report that directly instead of passing an
  // invalid descriptor to CreateFromFd.
  ASSERT_GE(model_fd, 0) << "Unable to open " << test_model_path;
  // NOTE(review): whether CreateFromFd takes ownership of (and closes)
  // model_fd is not visible here; if it does not, this descriptor leaks.
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromFd(model_fd);
  // ASSERT rather than EXPECT: *classifier below is invalid on error.
  ASSERT_TRUE(classifier.ok()) << classifier.status();

  // Passing true makes verify_classifier check the negative review.
  verify_classifier(std::move(*classifier), true);
}
141
142 // BertNLClassifier limits the input sequence to kMaxSeqLen, test when input is
143 // longer than this the classifier still works correctly.
TEST(BertNLClassifierTest,TestNLClassifier_ClassifyLongPositive_notOOB)144 TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
145 std::string test_model_path = absl::StrCat(GetExecutableDirectory(), kTestModelPath);
146 std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
147 std::stringstream ss_for_positive_review;
148 ss_for_positive_review
149 << "it's a charming and often affecting journey and this is a long";
150 for (int i = 0; i < kMaxSeqLen; ++i) {
151 ss_for_positive_review << " long";
152 }
153 ss_for_positive_review << " movie review";
154 StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
155 BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size());
156 EXPECT_TRUE(classifier.ok());
157
158 tflite::support::StatusOr<std::vector<core::Category>> results =
159 classifier.value()->ClassifyText(ss_for_positive_review.str());
160
161 EXPECT_TRUE(results.ok());
162 EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
163 GetCategoryWithClassName("negative", results.value())->score);
164 }
165
166 } // namespace
167
168 } // namespace nlclassifier
169 } // namespace text
170 } // namespace task
171 } // namespace tflite
172