1 /*
2 * Copyright 2018 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "aemu/base/testing/ProtobufMatchers.h"
17
18 #include <algorithm>
19 #include <regex>
20 #include <string>
21 #include <string_view>
22
23 #include "absl/log/check.h"
24 #include "gmock/gmock-matchers.h"
25 #include "gmock/gmock-more-matchers.h"
26 #include "google/protobuf/descriptor.h"
27 #include "google/protobuf/io/tokenizer.h"
28 #include "google/protobuf/message.h"
29 #include "google/protobuf/text_format.h"
30 #include "google/protobuf/util/message_differencer.h"
31
32 namespace android {
33 namespace internal {
34
35 // Utilities.
36 using google::protobuf::io::ColumnNumber;
37
38 class StringErrorCollector : public google::protobuf::io::ErrorCollector {
39 public:
StringErrorCollector(std::string * error_text)40 explicit StringErrorCollector(std::string* error_text) : error_text_(error_text) {}
41
RecordError(int line,ColumnNumber column,absl::string_view message)42 void RecordError(int line, ColumnNumber column, absl::string_view message) override {
43 std::ostringstream ss;
44 ss << "ERROR: " << line << "(" << column << ")" << message << "\n";
45 *error_text_ += ss.str();
46 }
47
RecordWarning(int line,ColumnNumber column,absl::string_view message)48 void RecordWarning(int line, ColumnNumber column, absl::string_view message) override {
49 std::ostringstream ss;
50 ss << "WARNING: " << line << "(" << column << ")" << message << "\n";
51 *error_text_ += ss.str();
52 }
53
54 private:
55 std::string* error_text_;
56 StringErrorCollector(const StringErrorCollector&) = delete;
57 StringErrorCollector& operator=(const StringErrorCollector&) = delete;
58 };
59
ParsePartialFromAscii(const std::string & pb_ascii,google::protobuf::Message * proto,std::string * error_text)60 bool ParsePartialFromAscii(const std::string& pb_ascii, google::protobuf::Message* proto,
61 std::string* error_text) {
62 google::protobuf::TextFormat::Parser parser;
63 StringErrorCollector collector(error_text);
64 parser.RecordErrorsTo(&collector);
65 parser.AllowPartialMessage(true);
66 return parser.ParseFromString(pb_ascii, proto);
67 }
68
69 // Returns true iff p and q can be compared (i.e. have the same descriptor).
ProtoComparable(const google::protobuf::Message & p,const google::protobuf::Message & q)70 bool ProtoComparable(const google::protobuf::Message& p, const google::protobuf::Message& q) {
71 return p.GetDescriptor() == q.GetDescriptor();
72 }
73
// Concatenates the elements of `strings` (each must be convertible to
// std::string_view), inserting `separator` between consecutive elements.
// Returns an empty string for an empty container.
//
// Builds the result directly in a std::string. The previous stringstream
// version depended on <sstream>, which this file never includes, and it paid
// for a stream just to concatenate.
template <typename Container>
std::string JoinStringPieces(const Container& strings, std::string_view separator) {
  std::string joined;
  bool first = true;
  for (const std::string_view piece : strings) {
    if (!first) {
      joined.append(separator.data(), separator.size());
    }
    joined.append(piece.data(), piece.size());
    first = false;
  }
  return joined;
}
84
85 // Find all the descriptors for the ignore_fields.
GetFieldDescriptors(const google::protobuf::Descriptor * proto_descriptor,const std::vector<std::string> & ignore_fields)86 std::vector<const google::protobuf::FieldDescriptor*> GetFieldDescriptors(
87 const google::protobuf::Descriptor* proto_descriptor,
88 const std::vector<std::string>& ignore_fields) {
89 std::vector<const google::protobuf::FieldDescriptor*> ignore_descriptors;
90 std::vector<std::string_view> remaining_descriptors;
91
92 const google::protobuf::DescriptorPool* pool = proto_descriptor->file()->pool();
93 for (const std::string& name : ignore_fields) {
94 if (const google::protobuf::FieldDescriptor* field = pool->FindFieldByName(name)) {
95 ignore_descriptors.push_back(field);
96 } else {
97 remaining_descriptors.push_back(name);
98 }
99 }
100
101 DCHECK(remaining_descriptors.empty())
102 << "Could not find fields for proto " << proto_descriptor->full_name()
103 << " with fully qualified names: " << JoinStringPieces(remaining_descriptors, ",");
104 return ignore_descriptors;
105 }
106
107 // Sets the ignored fields corresponding to ignore_fields in differencer. Dies
108 // if any is invalid.
SetIgnoredFieldsOrDie(const google::protobuf::Descriptor & root_descriptor,const std::vector<std::string> & ignore_fields,google::protobuf::util::MessageDifferencer * differencer)109 void SetIgnoredFieldsOrDie(const google::protobuf::Descriptor& root_descriptor,
110 const std::vector<std::string>& ignore_fields,
111 google::protobuf::util::MessageDifferencer* differencer) {
112 if (!ignore_fields.empty()) {
113 std::vector<const google::protobuf::FieldDescriptor*> ignore_descriptors =
114 GetFieldDescriptors(&root_descriptor, ignore_fields);
115 for (std::vector<const google::protobuf::FieldDescriptor*>::iterator it =
116 ignore_descriptors.begin();
117 it != ignore_descriptors.end(); ++it) {
118 differencer->IgnoreField(*it);
119 }
120 }
121 }
122
123 // Configures a MessageDifferencer and DefaultFieldComparator to use the logic
124 // described in comp. The configured differencer is the output of this function,
125 // but a FieldComparator must be provided to keep ownership clear.
ConfigureDifferencer(const internal::ProtoComparison & comp,google::protobuf::util::DefaultFieldComparator * comparator,google::protobuf::util::MessageDifferencer * differencer,const google::protobuf::Descriptor * descriptor)126 void ConfigureDifferencer(const internal::ProtoComparison& comp,
127 google::protobuf::util::DefaultFieldComparator* comparator,
128 google::protobuf::util::MessageDifferencer* differencer,
129 const google::protobuf::Descriptor* descriptor) {
130 differencer->set_message_field_comparison(comp.field_comp);
131 differencer->set_scope(comp.scope);
132 comparator->set_float_comparison(comp.float_comp);
133 comparator->set_treat_nan_as_equal(comp.treating_nan_as_equal);
134 differencer->set_repeated_field_comparison(comp.repeated_field_comp);
135 SetIgnoredFieldsOrDie(*descriptor, comp.ignore_fields, differencer);
136 if (comp.float_comp == internal::kProtoApproximate &&
137 (comp.has_custom_margin || comp.has_custom_fraction)) {
138 // Two fields will be considered equal if they're within the fraction
139 // _or_ within the margin. So setting the fraction to 0.0 makes this
140 // effectively a "SetMargin". Similarly, setting the margin to 0.0 makes
141 // this effectively a "SetFraction".
142 comparator->SetDefaultFractionAndMargin(comp.float_fraction, comp.float_margin);
143 }
144 differencer->set_field_comparator(comparator);
145 }
146
147 // Returns true iff actual and expected are comparable and match. The
148 // comp argument specifies how two are compared.
ProtoCompare(const internal::ProtoComparison & comp,const google::protobuf::Message & actual,const google::protobuf::Message & expected)149 bool ProtoCompare(const internal::ProtoComparison& comp, const google::protobuf::Message& actual,
150 const google::protobuf::Message& expected) {
151 if (!ProtoComparable(actual, expected)) return false;
152
153 google::protobuf::util::MessageDifferencer differencer;
154 google::protobuf::util::DefaultFieldComparator field_comparator;
155 ConfigureDifferencer(comp, &field_comparator, &differencer, actual.GetDescriptor());
156
157 // It's important for 'expected' to be the first argument here, as
158 // Compare() is not symmetric. When we do a partial comparison,
159 // only fields present in the first argument of Compare() are
160 // considered.
161 return differencer.Compare(expected, actual);
162 }
163
164 // Describes the types of the expected and the actual protocol buffer.
DescribeTypes(const google::protobuf::Message & expected,const google::protobuf::Message & actual)165 std::string DescribeTypes(const google::protobuf::Message& expected,
166 const google::protobuf::Message& actual) {
167 return "whose type should be " + expected.GetDescriptor()->full_name() + " but actually is " +
168 actual.GetDescriptor()->full_name();
169 }
170
171 // Prints the protocol buffer pointed to by proto.
PrintProtoPointee(const google::protobuf::Message * proto)172 std::string PrintProtoPointee(const google::protobuf::Message* proto) {
173 if (proto == NULL) return "";
174
175 return "which points to " + ::testing::PrintToString(*proto);
176 }
177
178 // Describes the differences between the two protocol buffers.
DescribeDiff(const internal::ProtoComparison & comp,const google::protobuf::Message & actual,const google::protobuf::Message & expected)179 std::string DescribeDiff(const internal::ProtoComparison& comp,
180 const google::protobuf::Message& actual,
181 const google::protobuf::Message& expected) {
182 google::protobuf::util::MessageDifferencer differencer;
183 google::protobuf::util::DefaultFieldComparator field_comparator;
184 ConfigureDifferencer(comp, &field_comparator, &differencer, actual.GetDescriptor());
185
186 std::string diff;
187 differencer.ReportDifferencesToString(&diff);
188
189 // We must put 'expected' as the first argument here, as Compare()
190 // reports the diff in terms of how the protobuf changes from the
191 // first argument to the second argument.
192 differencer.Compare(expected, actual);
193
194 // Removes the trailing '\n' in the diff to make the output look nicer.
195 if (diff.length() > 0 && *(diff.end() - 1) == '\n') {
196 diff.erase(diff.end() - 1);
197 }
198
199 return "with the difference:\n" + diff;
200 }
201
MatchAndExplain(const google::protobuf::Message & arg,bool is_matcher_for_pointer,::testing::MatchResultListener * listener) const202 bool ProtoMatcherBase::MatchAndExplain(
203 const google::protobuf::Message& arg,
204 bool is_matcher_for_pointer, // true iff this matcher is used to match
205 // a protobuf pointer.
206 ::testing::MatchResultListener* listener) const {
207 if (must_be_initialized_ && !arg.IsInitialized()) {
208 *listener << "which isn't fully initialized";
209 return false;
210 }
211
212 const google::protobuf::Message* const expected = CreateExpectedProto(arg, listener);
213 if (expected == NULL) return false;
214
215 // Protobufs of different types cannot be compared.
216 const bool comparable = ProtoComparable(arg, *expected);
217 const bool match = comparable && ProtoCompare(comp(), arg, *expected);
218
219 // Explaining the match result is expensive. We don't want to waste
220 // time calculating an explanation if the listener isn't interested.
221 if (listener->IsInterested()) {
222 const char* sep = "";
223 if (is_matcher_for_pointer) {
224 *listener << PrintProtoPointee(&arg);
225 sep = ",\n";
226 }
227
228 if (!comparable) {
229 *listener << sep << DescribeTypes(*expected, arg);
230 } else if (!match) {
231 *listener << sep << DescribeDiff(comp(), arg, *expected);
232 }
233 }
234
235 DeleteExpectedProto(expected);
236 return match;
237 }
238
239 } // namespace internal
240 } // namespace android
241