1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/node_matchers.h"
17
18 #include <utility>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/str_split.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/graph/graph_node_util.h"
29
30 namespace tensorflow {
31 namespace testing {
32 namespace matchers {
33 namespace {
34
35 using impl::NodeMatcherProperties;
36 using impl::OutEdge;
37
IndentAllButFirstLine(absl::string_view text)38 string IndentAllButFirstLine(absl::string_view text) {
39 std::vector<std::string> lines = absl::StrSplit(text, '\n');
40 for (int i = 1; i < lines.size(); i++) {
41 lines[i].insert(0, " ");
42 }
43 return absl::StrJoin(lines, "\n");
44 }
45
46 template <typename T>
CompareTensor(const Tensor & actual,const Tensor & expected,::testing::MatchResultListener * listener)47 bool CompareTensor(const Tensor& actual, const Tensor& expected,
48 ::testing::MatchResultListener* listener) {
49 if (actual.NumElements() != expected.NumElements()) {
50 if (listener->IsInterested()) {
51 *listener << "\nwas looking for tensor with " << expected.NumElements()
52 << " elements, found tensor with " << actual.NumElements()
53 << " elements";
54 return false;
55 }
56 }
57
58 for (int64_t i = 0, e = actual.NumElements(); i < e; i++) {
59 if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
60 *listener << "\nmismatch in constant tensor at index " << i
61 << " expected = " << expected.flat<T>()(i)
62 << " actual = " << actual.flat<T>()(i);
63 return false;
64 }
65 }
66
67 return true;
68 }
69
MatchAndExplainTensor(const Tensor & tensor,const Tensor & expected_tensor,::testing::MatchResultListener * listener)70 bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
71 ::testing::MatchResultListener* listener) {
72 if (tensor.dtype() != expected_tensor.dtype()) {
73 if (listener->IsInterested()) {
74 *listener << "\nexpected tensor of type "
75 << DataType_Name(expected_tensor.dtype())
76 << " but found one of type " << DataType_Name(tensor.dtype());
77 return false;
78 }
79 }
80
81 switch (tensor.dtype()) {
82 case DT_HALF:
83 return CompareTensor<Eigen::half>(tensor, expected_tensor, listener);
84 case DT_FLOAT:
85 return CompareTensor<float>(tensor, expected_tensor, listener);
86 case DT_DOUBLE:
87 return CompareTensor<double>(tensor, expected_tensor, listener);
88 case DT_INT8:
89 return CompareTensor<int8>(tensor, expected_tensor, listener);
90 case DT_INT16:
91 return CompareTensor<int16>(tensor, expected_tensor, listener);
92 case DT_INT32:
93 return CompareTensor<int32>(tensor, expected_tensor, listener);
94 case DT_INT64:
95 return CompareTensor<int64_t>(tensor, expected_tensor, listener);
96 case DT_UINT8:
97 return CompareTensor<uint8>(tensor, expected_tensor, listener);
98 case DT_UINT16:
99 return CompareTensor<uint16>(tensor, expected_tensor, listener);
100 case DT_UINT32:
101 return CompareTensor<uint32>(tensor, expected_tensor, listener);
102 case DT_UINT64:
103 return CompareTensor<uint64>(tensor, expected_tensor, listener);
104 default:
105 LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly.
106 << DataType_Name(tensor.dtype());
107 }
108 }
109
110 struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
MatchAndExplaintensorflow::testing::matchers::__anon31c7a6db0111::NodeMatcher111 bool MatchAndExplain(
112 const Node* node,
113 ::testing::MatchResultListener* listener) const override {
114 if (op && node->type_string() != *op) {
115 if (listener->IsInterested()) {
116 *listener << "\nexpected op " << *op << " but found "
117 << node->type_string();
118 }
119 return false;
120 }
121
122 if (assigned_device && node->assigned_device_name() != *assigned_device) {
123 if (listener->IsInterested()) {
124 *listener << "\nexpected assigned_device " << *assigned_device
125 << " but found \"" << node->assigned_device_name() << "\"";
126 }
127 return false;
128 }
129
130 if (name && node->name() != *name) {
131 if (listener->IsInterested()) {
132 *listener << "\nexpected name " << *name << " but found "
133 << node->name();
134 }
135 return false;
136 }
137
138 if (constant_value) {
139 const TensorProto* proto = nullptr;
140 if (!TryGetNodeAttr(node->def(), "value", &proto)) {
141 if (listener->IsInterested()) {
142 *listener << "\ncould not find \"value\" attribute in node";
143 }
144 return false;
145 }
146
147 Tensor tensor(proto->dtype());
148 if (!tensor.FromProto(*proto)) {
149 if (listener->IsInterested()) {
150 *listener << "\ncould not convert TensorProto in \"value\" attribute "
151 "to Tensor";
152 }
153 return false;
154 }
155
156 if (!MatchAndExplainTensor(/*tensor=*/tensor,
157 /*expected_tensor=*/*constant_value,
158 listener)) {
159 return false;
160 }
161 }
162
163 if (input_matchers) {
164 if (input_matchers->size() != node->num_inputs()) {
165 if (listener->IsInterested()) {
166 *listener << "\nexpected " << input_matchers->size()
167 << " inputs but node has " << node->num_inputs();
168 }
169 return false;
170 }
171
172 for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
173 input_idx++) {
174 if (!MatchAndExplainInput(node, input_idx, listener)) {
175 return false;
176 }
177 }
178 }
179
180 std::vector<const Node*> control_deps;
181 for (const Edge* e : node->in_edges()) {
182 if (e->IsControlEdge()) {
183 control_deps.push_back(e->src());
184 }
185 }
186
187 ::testing::StringMatchResultListener inner_listener;
188 if (control_dep_set &&
189 !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
190 if (listener->IsInterested()) {
191 string explanation = inner_listener.str();
192 if (!explanation.empty()) {
193 explanation = absl::StrCat(", ", explanation, ",");
194 }
195 *listener << "ctrl_deps" << explanation << " does not match expected: ";
196 control_dep_set->DescribeTo(listener->stream());
197 }
198 return false;
199 }
200
201 const AttrValueMap attr_value_map = node->def().attr();
202 for (const auto& attr_kv_pair : attrs) {
203 auto it = attr_value_map.find(attr_kv_pair.first);
204 if (it == attr_value_map.end()) {
205 if (listener->IsInterested()) {
206 *listener << "did not find attribute named \"" << attr_kv_pair.first
207 << "\" in node";
208 }
209 return false;
210 }
211 if (attr_kv_pair.second &&
212 !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) {
213 if (listener->IsInterested()) {
214 *listener << "attribute named " << attr_kv_pair.first
215 << " does not match value; expected: \""
216 << SummarizeAttrValue(*attr_kv_pair.second)
217 << "\", found: \"" << SummarizeAttrValue(it->second)
218 << "\"";
219 }
220 return false;
221 }
222 }
223
224 return true;
225 }
226
DescribeTotensorflow::testing::matchers::__anon31c7a6db0111::NodeMatcher227 void DescribeTo(::std::ostream* os) const override {
228 std::vector<string> predicates;
229
230 if (name) {
231 predicates.push_back(absl::StrCat("name: ", *name));
232 }
233
234 if (op) {
235 predicates.push_back(absl::StrCat("op: ", *op));
236 }
237
238 if (assigned_device) {
239 predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
240 }
241
242 bool printed_something = !predicates.empty();
243
244 *os << absl::StrJoin(predicates, ", ");
245
246 if (constant_value) {
247 printed_something = true;
248 *os << "constant value: " << constant_value->DebugString();
249 }
250
251 if (input_matchers) {
252 if (!input_matchers->empty()) {
253 printed_something = true;
254 *os << " with " << (input_matchers->size() == 1 ? "only " : "")
255 << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
256 }
257
258 if (input_matchers->size() == 1) {
259 ::std::stringstream ss;
260 input_matchers->front().DescribeTo(&ss);
261 printed_something = true;
262 *os << "matching " << ss.str();
263 } else {
264 int edge_idx = 0;
265 for (const ::testing::Matcher<OutEdge>& matcher : (*input_matchers)) {
266 *os << "\n [" << edge_idx << "] matching (";
267 ::std::stringstream ss;
268 matcher.DescribeTo(&ss);
269 printed_something = true;
270 *os << IndentAllButFirstLine(ss.str());
271 *os << ")";
272 edge_idx++;
273 }
274 }
275 }
276
277 if (control_dep_set) {
278 printed_something = true;
279 *os << " and control deps ";
280 control_dep_set->DescribeTo(os);
281 }
282
283 if (!attrs.empty()) {
284 printed_something = true;
285 std::vector<string> attrs_str;
286 absl::c_transform(
287 attrs, std::back_inserter(attrs_str),
288 [](const std::pair<string, std::optional<AttrValue>>& attr_kv_pair) {
289 return absl::StrCat(attr_kv_pair.first, "->",
290 attr_kv_pair.second
291 ? SummarizeAttrValue(*attr_kv_pair.second)
292 : "*");
293 });
294 *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ")
295 << "]";
296 }
297
298 if (!printed_something) {
299 *os << "is any node";
300 }
301 }
302
MatchAndExplainInputtensorflow::testing::matchers::__anon31c7a6db0111::NodeMatcher303 bool MatchAndExplainInput(const Node* node, int input_idx,
304 ::testing::MatchResultListener* listener) const {
305 const Edge* edge;
306 if (!node->input_edge(input_idx, &edge).ok()) {
307 if (listener->IsInterested()) {
308 *listener << "\ncould not find incoming edge for input " << input_idx;
309 }
310 return false;
311 }
312
313 ::testing::StringMatchResultListener inner_listener;
314 OutEdge input = {edge->src(), edge->src_output()};
315 if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
316 return true;
317 }
318
319 if (listener->IsInterested()) {
320 *listener << "\ninput " << input_idx << " does not match expected:\n";
321 (*input_matchers)[input_idx].DescribeTo(listener->stream());
322 string explanation = inner_listener.str();
323 if (!explanation.empty()) {
324 *listener << ", " << explanation;
325 }
326 }
327 return false;
328 }
329
330 std::optional<string> op;
331 std::optional<string> name;
332 std::optional<string> assigned_device;
333 std::optional<Tensor> constant_value;
334 std::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers;
335 std::optional<::testing::Matcher<absl::Span<const Node* const>>>
336 control_dep_set;
337 std::map<string, std::optional<AttrValue>> attrs;
338 };
339
340 // Matches a dst and dst_output on an input edge. Today we only use this with
341 // dst_output=0 but we will eventually need to support multi-output operations.
342 class OutEdgeMatcher : public ::testing::MatcherInterface<OutEdge> {
343 public:
OutEdgeMatcher(::testing::Matcher<const Node * > src_matcher,int src_oidx)344 OutEdgeMatcher(::testing::Matcher<const Node*> src_matcher, int src_oidx)
345 : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {}
346
MatchAndExplain(OutEdge out_edge,::testing::MatchResultListener * listener) const347 bool MatchAndExplain(
348 OutEdge out_edge,
349 ::testing::MatchResultListener* listener) const override {
350 ::testing::StringMatchResultListener inner_listener;
351 if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) {
352 if (listener->IsInterested()) {
353 *listener << "\nsource does not match expected ";
354 src_matcher_.DescribeTo(listener->stream());
355 string explanation = inner_listener.str();
356 if (!explanation.empty()) {
357 *listener << "\n\t" << explanation;
358 }
359 }
360 return false;
361 }
362 if (out_edge.second != src_oidx_) {
363 if (listener->IsInterested()) {
364 *listener << "\nexpected output slot to be " << src_oidx_
365 << " but found " << out_edge.second;
366 }
367 return false;
368 }
369
370 return true;
371 }
372
DescribeTo(::std::ostream * os) const373 void DescribeTo(::std::ostream* os) const override {
374 if (src_oidx_) {
375 *os << "output slot: " << src_oidx_ << ", source: (";
376 }
377
378 src_matcher_.DescribeTo(os);
379
380 if (src_oidx_) {
381 *os << ")";
382 }
383 }
384
385 private:
386 ::testing::Matcher<const Node*> src_matcher_;
387 int src_oidx_;
388 };
389 } // namespace
390
NodeWith(absl::Span<const NodeMatcherProperties> props)391 ::testing::Matcher<const Node*> impl::NodeWith(
392 absl::Span<const NodeMatcherProperties> props) {
393 NodeMatcher* matcher = new NodeMatcher();
394 for (const NodeMatcherProperties& prop : props) {
395 if (prop.name()) {
396 DCHECK(!matcher->name);
397 matcher->name = prop.name();
398 }
399
400 if (prop.op()) {
401 DCHECK(!matcher->op);
402 matcher->op = prop.op();
403 }
404
405 if (prop.constant_value()) {
406 DCHECK(!matcher->constant_value);
407 matcher->constant_value = prop.constant_value();
408 }
409
410 if (prop.assigned_device()) {
411 DCHECK(!matcher->assigned_device);
412 matcher->assigned_device = prop.assigned_device();
413 }
414
415 if (prop.inputs()) {
416 DCHECK(!matcher->input_matchers);
417 matcher->input_matchers = *prop.inputs();
418 }
419
420 if (prop.control_deps()) {
421 DCHECK(!matcher->control_dep_set);
422 matcher->control_dep_set =
423 ::testing::UnorderedElementsAreArray(*prop.control_deps());
424 }
425
426 if (prop.attr()) {
427 auto insert_result = matcher->attrs.insert(*prop.attr());
428 DCHECK(insert_result.second);
429 }
430 }
431
432 return ::testing::MakeMatcher(matcher);
433 }
434
Name(string name)435 impl::NodeMatcherProperties Name(string name) {
436 impl::NodeMatcherProperties props;
437 props.set_name(std::move(name));
438 return props;
439 }
440
441 // Matches a node with op `op`.
Op(string op)442 impl::NodeMatcherProperties Op(string op) {
443 impl::NodeMatcherProperties props;
444 props.set_op(std::move(op));
445 return props;
446 }
447
448 // Matches a node with assigned device `assigned_device`.
AssignedDevice(string assigned_device)449 impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
450 impl::NodeMatcherProperties props;
451 props.set_assigned_device(std::move(assigned_device));
452 return props;
453 }
454
Inputs(absl::Span<const::testing::Matcher<OutEdge>> inputs)455 impl::NodeMatcherProperties impl::Inputs(
456 absl::Span<const ::testing::Matcher<OutEdge>> inputs) {
457 std::vector<::testing::Matcher<OutEdge>> inputs_vector;
458 absl::c_copy(inputs, std::back_inserter(inputs_vector));
459
460 impl::NodeMatcherProperties props;
461 props.set_inputs(std::move(inputs_vector));
462 return props;
463 }
464
CtrlDeps(absl::Span<const::testing::Matcher<const Node * >> control_deps)465 impl::NodeMatcherProperties impl::CtrlDeps(
466 absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
467 std::vector<::testing::Matcher<const Node*>> control_deps_vector;
468 absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
469
470 impl::NodeMatcherProperties props;
471 props.set_control_deps(std::move(control_deps_vector));
472 return props;
473 }
474
AttrLiteralHelper(const std::pair<string,bool> & bool_attr)475 std::pair<string, AttrValue> impl::AttrLiteralHelper(
476 const std::pair<string, bool>& bool_attr) {
477 AttrValue attr_value;
478 attr_value.set_b(bool_attr.second);
479 return {bool_attr.first, attr_value};
480 }
481
AttrLiteralHelper(const std::pair<string,absl::Span<const int>> & int_list_attr)482 std::pair<string, AttrValue> impl::AttrLiteralHelper(
483 const std::pair<string, absl::Span<const int>>& int_list_attr) {
484 AttrValue attr_value;
485 AttrValue::ListValue* list = attr_value.mutable_list();
486 for (int i : int_list_attr.second) {
487 list->add_i(i);
488 }
489 return {int_list_attr.first, attr_value};
490 }
491
AttrLiteralHelper(const std::pair<string,absl::Span<const string>> & string_list_attr)492 std::pair<string, AttrValue> impl::AttrLiteralHelper(
493 const std::pair<string, absl::Span<const string>>& string_list_attr) {
494 AttrValue attr_value;
495 AttrValue::ListValue* list = attr_value.mutable_list();
496 for (const string& s : string_list_attr.second) {
497 list->add_s(s);
498 }
499 return {string_list_attr.first, attr_value};
500 }
501
Attr(std::pair<string,AttrValue> attr)502 impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
503 impl::NodeMatcherProperties props;
504 props.set_attr(std::move(attr));
505 return props;
506 }
507
Attr(string name)508 impl::NodeMatcherProperties impl::Attr(string name) {
509 impl::NodeMatcherProperties props;
510 props.set_attr({std::move(name), std::nullopt});
511 return props;
512 }
513
ConstantValue(const::tensorflow::Input::Initializer & val)514 NodeMatcherProperties ConstantValue(
515 const ::tensorflow::Input::Initializer& val) {
516 TF_CHECK_OK(val.status);
517 NodeMatcherProperties props;
518 props.set_constant_value(val.tensor);
519 return props;
520 }
521
Const(const::tensorflow::Input::Initializer & val)522 ::testing::Matcher<impl::OutEdge> Const(
523 const ::tensorflow::Input::Initializer& val) {
524 return Out(NodeWith(ConstantValue(val)));
525 }
Out(int oidx,::testing::Matcher<const Node * > node_matcher)526 ::testing::Matcher<impl::OutEdge> Out(
527 int oidx, ::testing::Matcher<const Node*> node_matcher) {
528 return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx));
529 }
530 } // namespace matchers
531
FindNodeByName(Graph * g,absl::string_view name)532 Node* FindNodeByName(Graph* g, absl::string_view name) {
533 for (Node* n : g->nodes()) {
534 if (n->name() == name) {
535 return n;
536 }
537 }
538
539 return nullptr;
540 }
541 } // namespace testing
542
PrintTo(const Node * n,::std::ostream * os)543 void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
PrintTo(Node * n,::std::ostream * os)544 void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
545 } // namespace tensorflow
546