1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"

#include <fcntl.h>
#include <unistd.h>

#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "android-base/file.h"
#include "tensorflow_lite_support/cc/port/gmock.h"
#include "tensorflow_lite_support/cc/port/gtest.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
24 
25 namespace tflite {
26 namespace task {
27 namespace text {
28 namespace nlclassifier {
29 
30 namespace {
31 
32 using ::android::base::GetExecutableDirectory;
33 using ::testing::HasSubstr;
34 using ::tflite::support::kTfLiteSupportPayload;
35 using ::tflite::support::StatusOr;
36 using ::tflite::support::TfLiteSupportStatus;
37 using ::tflite::task::core::Category;
38 using ::tflite::task::core::LoadBinaryContent;
39 
// Test model location; the tests prepend GetExecutableDirectory() to it.
constexpr char kTestModelPath[] =
    "/tensorflow_lite_support/cc/test/testdata/"
    "task/text/test_model_nl_classifier_bert.tflite";

// A path that is expected not to exist, used to exercise the failure path.
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";

// Sequence-length bound the long-input test deliberately exceeds.
constexpr int kMaxSeqLen{128};
47 
// Builds the classifier straight from the model file on disk and expects
// construction to succeed.
TEST(BertNLClassifierTest, TestNLClassifierCreationFilePath) {
  const std::string model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  const StatusOr<std::unique_ptr<BertNLClassifier>> created =
      BertNLClassifier::CreateFromFile(model_path);
  EXPECT_TRUE(created.ok());
}
54 
// Reads the model file into memory and builds the classifier from that
// in-memory buffer; construction is expected to succeed.
TEST(BertNLClassifierTest, TestNLClassifierCreationBinary) {
  const std::string model_bytes = LoadBinaryContent(
      absl::StrCat(GetExecutableDirectory(), kTestModelPath).c_str());
  const StatusOr<std::unique_ptr<BertNLClassifier>> created =
      BertNLClassifier::CreateFromBuffer(model_bytes.data(),
                                         model_bytes.size());
  EXPECT_TRUE(created.ok());
}
62 
// Creating a classifier from a missing file must fail with kNotFound, a
// descriptive message, and the kFileNotFoundError support payload.
TEST(BertNLClassifierTest, TestNLClassifierCreationFailure) {
  const StatusOr<std::unique_ptr<BertNLClassifier>> created =
      BertNLClassifier::CreateFromFile(kInvalidModelPath);
  const auto& status = created.status();

  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(status.message(),
              HasSubstr("Unable to open file at i/do/not/exist.tflite"));
  EXPECT_THAT(
      status.GetPayload(kTfLiteSupportPayload),
      testing::Optional(absl::Cord(
          absl::StrCat(TfLiteSupportStatus::kFileNotFoundError))));
}
74 
GetCategoryWithClassName(const std::string & class_name,std::vector<Category> & categories)75 Category* GetCategoryWithClassName(const std::string& class_name,
76                                    std::vector<Category>& categories) {
77   for (Category& category : categories) {
78     if (category.class_name == class_name) {
79       return &category;
80     }
81   }
82   return nullptr;
83 }
84 
verify_classifier(std::unique_ptr<BertNLClassifier> classifier,bool verify_positive)85 void verify_classifier(std::unique_ptr<BertNLClassifier> classifier,
86                        bool verify_positive) {
87   if (verify_positive) {
88     tflite::support::StatusOr<std::vector<core::Category>> results =
89         classifier->ClassifyText("unflinchingly bleak and desperate");
90 
91     EXPECT_TRUE(results.ok());
92     EXPECT_GT(GetCategoryWithClassName("negative", results.value())->score,
93               GetCategoryWithClassName("positive", results.value())->score);
94   } else {
95     tflite::support::StatusOr<std::vector<core::Category>> results =
96         classifier->ClassifyText("it's a charming and often affecting journey");
97 
98     EXPECT_TRUE(results.ok());
99     EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
100               GetCategoryWithClassName("negative", results.value())->score);
101   }
102 }
103 
// Builds the classifier from an in-memory copy of the model and checks that a
// clearly negative review is classified as negative.
TEST(BertNLClassifierTest, TestNLClassifier_ClassifyNegative) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromBuffer(model_buffer.data(),
                                         model_buffer.size());
  // *classifier is dereferenced below, so creation failure must be fatal.
  ASSERT_TRUE(classifier.ok()) << classifier.status().message();

  // A `true` flag makes verify_classifier feed its negative sample sentence
  // ("unflinchingly bleak and desperate") and expect "negative" to win.
  verify_classifier(std::move(*classifier), true);
}
113 
// Builds the classifier from an in-memory copy of the model and checks that a
// clearly positive review is classified as positive.
TEST(BertNLClassifierTest, TestNLClassifier_ClassifyPositive) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromBuffer(model_buffer.data(),
                                         model_buffer.size());
  // *classifier is dereferenced below, so creation failure must be fatal.
  ASSERT_TRUE(classifier.ok()) << classifier.status().message();

  // A `false` flag makes verify_classifier feed its positive sample sentence
  // ("it's a charming and often affecting journey") and expect "positive".
  verify_classifier(std::move(*classifier), false);
}
123 
// Builds the classifier from an open file descriptor and checks that a clearly
// positive review is classified as positive.
TEST(BertNLClassifierTest, TestNLClassifierFd_ClassifyPositive) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  const int fd = open(test_model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Unable to open " << test_model_path;

  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromFd(fd);
  // Non-fatal so that `fd` is still closed below when creation fails.
  const bool created = classifier.ok();
  EXPECT_TRUE(created) << classifier.status().message();
  if (created) {
    // `false` makes verify_classifier feed its positive sample sentence. The
    // classifier is destroyed inside verify_classifier, before `fd` is closed.
    verify_classifier(std::move(*classifier), false);
  }

  // NOTE(review): assumes CreateFromFd does not take ownership of `fd`
  // (closing it here would otherwise be a double close) — confirm against
  // the library's file-handling code. Without this the test leaks one fd.
  close(fd);
}
132 
// Builds the classifier from an open file descriptor and checks that a clearly
// negative review is classified as negative.
TEST(BertNLClassifierTest, TestNLClassifierFd_ClassifyNegative) {
  std::string test_model_path =
      absl::StrCat(GetExecutableDirectory(), kTestModelPath);
  const int fd = open(test_model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Unable to open " << test_model_path;

  StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
      BertNLClassifier::CreateFromFd(fd);
  // Non-fatal so that `fd` is still closed below when creation fails.
  const bool created = classifier.ok();
  EXPECT_TRUE(created) << classifier.status().message();
  if (created) {
    // `true` makes verify_classifier feed its negative sample sentence. The
    // classifier is destroyed inside verify_classifier, before `fd` is closed.
    verify_classifier(std::move(*classifier), true);
  }

  // NOTE(review): assumes CreateFromFd does not take ownership of `fd`
  // (closing it here would otherwise be a double close) — confirm against
  // the library's file-handling code. Without this the test leaks one fd.
  close(fd);
}
141 
142 // BertNLClassifier limits the input sequence to kMaxSeqLen, test when input is
143 // longer than this the classifier still works correctly.
TEST(BertNLClassifierTest,TestNLClassifier_ClassifyLongPositive_notOOB)144 TEST(BertNLClassifierTest, TestNLClassifier_ClassifyLongPositive_notOOB) {
145   std::string test_model_path = absl::StrCat(GetExecutableDirectory(), kTestModelPath);
146   std::string model_buffer = LoadBinaryContent(test_model_path.c_str());
147   std::stringstream ss_for_positive_review;
148   ss_for_positive_review
149       << "it's a charming and often affecting journey and this is a long";
150   for (int i = 0; i < kMaxSeqLen; ++i) {
151     ss_for_positive_review << " long";
152   }
153   ss_for_positive_review << " movie review";
154   StatusOr<std::unique_ptr<BertNLClassifier>> classifier =
155       BertNLClassifier::CreateFromBuffer(model_buffer.data(), model_buffer.size());
156   EXPECT_TRUE(classifier.ok());
157 
158   tflite::support::StatusOr<std::vector<core::Category>> results =
159       classifier.value()->ClassifyText(ss_for_positive_review.str());
160 
161   EXPECT_TRUE(results.ok());
162   EXPECT_GT(GetCategoryWithClassName("positive", results.value())->score,
163             GetCategoryWithClassName("negative", results.value())->score);
164 }
165 
166 }  // namespace
167 
168 }  // namespace nlclassifier
169 }  // namespace text
170 }  // namespace task
171 }  // namespace tflite
172