xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/node_matchers.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/node_matchers.h"
17 
18 #include <utility>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/strings/str_replace.h"
24 #include "absl/strings/str_split.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/graph/graph_node_util.h"
29 
30 namespace tensorflow {
31 namespace testing {
32 namespace matchers {
33 namespace {
34 
35 using impl::NodeMatcherProperties;
36 using impl::OutEdge;
37 
IndentAllButFirstLine(absl::string_view text)38 string IndentAllButFirstLine(absl::string_view text) {
39   std::vector<std::string> lines = absl::StrSplit(text, '\n');
40   for (int i = 1; i < lines.size(); i++) {
41     lines[i].insert(0, "  ");
42   }
43   return absl::StrJoin(lines, "\n");
44 }
45 
46 template <typename T>
CompareTensor(const Tensor & actual,const Tensor & expected,::testing::MatchResultListener * listener)47 bool CompareTensor(const Tensor& actual, const Tensor& expected,
48                    ::testing::MatchResultListener* listener) {
49   if (actual.NumElements() != expected.NumElements()) {
50     if (listener->IsInterested()) {
51       *listener << "\nwas looking for tensor with " << expected.NumElements()
52                 << " elements, found tensor with " << actual.NumElements()
53                 << " elements";
54       return false;
55     }
56   }
57 
58   for (int64_t i = 0, e = actual.NumElements(); i < e; i++) {
59     if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
60       *listener << "\nmismatch in constant tensor at index " << i
61                 << " expected = " << expected.flat<T>()(i)
62                 << " actual = " << actual.flat<T>()(i);
63       return false;
64     }
65   }
66 
67   return true;
68 }
69 
MatchAndExplainTensor(const Tensor & tensor,const Tensor & expected_tensor,::testing::MatchResultListener * listener)70 bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
71                            ::testing::MatchResultListener* listener) {
72   if (tensor.dtype() != expected_tensor.dtype()) {
73     if (listener->IsInterested()) {
74       *listener << "\nexpected tensor of type "
75                 << DataType_Name(expected_tensor.dtype())
76                 << " but found one of type " << DataType_Name(tensor.dtype());
77       return false;
78     }
79   }
80 
81   switch (tensor.dtype()) {
82     case DT_HALF:
83       return CompareTensor<Eigen::half>(tensor, expected_tensor, listener);
84     case DT_FLOAT:
85       return CompareTensor<float>(tensor, expected_tensor, listener);
86     case DT_DOUBLE:
87       return CompareTensor<double>(tensor, expected_tensor, listener);
88     case DT_INT8:
89       return CompareTensor<int8>(tensor, expected_tensor, listener);
90     case DT_INT16:
91       return CompareTensor<int16>(tensor, expected_tensor, listener);
92     case DT_INT32:
93       return CompareTensor<int32>(tensor, expected_tensor, listener);
94     case DT_INT64:
95       return CompareTensor<int64_t>(tensor, expected_tensor, listener);
96     case DT_UINT8:
97       return CompareTensor<uint8>(tensor, expected_tensor, listener);
98     case DT_UINT16:
99       return CompareTensor<uint16>(tensor, expected_tensor, listener);
100     case DT_UINT32:
101       return CompareTensor<uint32>(tensor, expected_tensor, listener);
102     case DT_UINT64:
103       return CompareTensor<uint64>(tensor, expected_tensor, listener);
104     default:
105       LOG(FATAL) << "Unsupported dtype "  // Crash ok: testonly.
106                  << DataType_Name(tensor.dtype());
107   }
108 }
109 
110 struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
MatchAndExplaintensorflow::testing::matchers::__anon31c7a6db0111::NodeMatcher111   bool MatchAndExplain(
112       const Node* node,
113       ::testing::MatchResultListener* listener) const override {
114     if (op && node->type_string() != *op) {
115       if (listener->IsInterested()) {
116         *listener << "\nexpected op " << *op << " but found "
117                   << node->type_string();
118       }
119       return false;
120     }
121 
122     if (assigned_device && node->assigned_device_name() != *assigned_device) {
123       if (listener->IsInterested()) {
124         *listener << "\nexpected assigned_device " << *assigned_device
125                   << " but found \"" << node->assigned_device_name() << "\"";
126       }
127       return false;
128     }
129 
130     if (name && node->name() != *name) {
131       if (listener->IsInterested()) {
132         *listener << "\nexpected name " << *name << " but found "
133                   << node->name();
134       }
135       return false;
136     }
137 
138     if (constant_value) {
139       const TensorProto* proto = nullptr;
140       if (!TryGetNodeAttr(node->def(), "value", &proto)) {
141         if (listener->IsInterested()) {
142           *listener << "\ncould not find \"value\" attribute in node";
143         }
144         return false;
145       }
146 
147       Tensor tensor(proto->dtype());
148       if (!tensor.FromProto(*proto)) {
149         if (listener->IsInterested()) {
150           *listener << "\ncould not convert TensorProto in \"value\" attribute "
151                        "to Tensor";
152         }
153         return false;
154       }
155 
156       if (!MatchAndExplainTensor(/*tensor=*/tensor,
157                                  /*expected_tensor=*/*constant_value,
158                                  listener)) {
159         return false;
160       }
161     }
162 
163     if (input_matchers) {
164       if (input_matchers->size() != node->num_inputs()) {
165         if (listener->IsInterested()) {
166           *listener << "\nexpected " << input_matchers->size()
167                     << " inputs but node has " << node->num_inputs();
168         }
169         return false;
170       }
171 
172       for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
173            input_idx++) {
174         if (!MatchAndExplainInput(node, input_idx, listener)) {
175           return false;
176         }
177       }
178     }
179 
180     std::vector<const Node*> control_deps;
181     for (const Edge* e : node->in_edges()) {
182       if (e->IsControlEdge()) {
183         control_deps.push_back(e->src());
184       }
185     }
186 
187     ::testing::StringMatchResultListener inner_listener;
188     if (control_dep_set &&
189         !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
190       if (listener->IsInterested()) {
191         string explanation = inner_listener.str();
192         if (!explanation.empty()) {
193           explanation = absl::StrCat(", ", explanation, ",");
194         }
195         *listener << "ctrl_deps" << explanation << " does not match expected: ";
196         control_dep_set->DescribeTo(listener->stream());
197       }
198       return false;
199     }
200 
201     const AttrValueMap attr_value_map = node->def().attr();
202     for (const auto& attr_kv_pair : attrs) {
203       auto it = attr_value_map.find(attr_kv_pair.first);
204       if (it == attr_value_map.end()) {
205         if (listener->IsInterested()) {
206           *listener << "did not find attribute named \"" << attr_kv_pair.first
207                     << "\" in node";
208         }
209         return false;
210       }
211       if (attr_kv_pair.second &&
212           !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) {
213         if (listener->IsInterested()) {
214           *listener << "attribute named " << attr_kv_pair.first
215                     << " does not match value; expected: \""
216                     << SummarizeAttrValue(*attr_kv_pair.second)
217                     << "\", found: \"" << SummarizeAttrValue(it->second)
218                     << "\"";
219         }
220         return false;
221       }
222     }
223 
224     return true;
225   }
226 
DescribeTotensorflow::testing::matchers::__anon31c7a6db0111::NodeMatcher227   void DescribeTo(::std::ostream* os) const override {
228     std::vector<string> predicates;
229 
230     if (name) {
231       predicates.push_back(absl::StrCat("name: ", *name));
232     }
233 
234     if (op) {
235       predicates.push_back(absl::StrCat("op: ", *op));
236     }
237 
238     if (assigned_device) {
239       predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
240     }
241 
242     bool printed_something = !predicates.empty();
243 
244     *os << absl::StrJoin(predicates, ", ");
245 
246     if (constant_value) {
247       printed_something = true;
248       *os << "constant value: " << constant_value->DebugString();
249     }
250 
251     if (input_matchers) {
252       if (!input_matchers->empty()) {
253         printed_something = true;
254         *os << " with " << (input_matchers->size() == 1 ? "only " : "")
255             << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
256       }
257 
258       if (input_matchers->size() == 1) {
259         ::std::stringstream ss;
260         input_matchers->front().DescribeTo(&ss);
261         printed_something = true;
262         *os << "matching " << ss.str();
263       } else {
264         int edge_idx = 0;
265         for (const ::testing::Matcher<OutEdge>& matcher : (*input_matchers)) {
266           *os << "\n  [" << edge_idx << "] matching (";
267           ::std::stringstream ss;
268           matcher.DescribeTo(&ss);
269           printed_something = true;
270           *os << IndentAllButFirstLine(ss.str());
271           *os << ")";
272           edge_idx++;
273         }
274       }
275     }
276 
277     if (control_dep_set) {
278       printed_something = true;
279       *os << " and control deps ";
280       control_dep_set->DescribeTo(os);
281     }
282 
283     if (!attrs.empty()) {
284       printed_something = true;
285       std::vector<string> attrs_str;
286       absl::c_transform(
287           attrs, std::back_inserter(attrs_str),
288           [](const std::pair<string, std::optional<AttrValue>>& attr_kv_pair) {
289             return absl::StrCat(attr_kv_pair.first, "->",
290                                 attr_kv_pair.second
291                                     ? SummarizeAttrValue(*attr_kv_pair.second)
292                                     : "*");
293           });
294       *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ")
295           << "]";
296     }
297 
298     if (!printed_something) {
299       *os << "is any node";
300     }
301   }
302 
MatchAndExplainInputtensorflow::testing::matchers::__anon31c7a6db0111::NodeMatcher303   bool MatchAndExplainInput(const Node* node, int input_idx,
304                             ::testing::MatchResultListener* listener) const {
305     const Edge* edge;
306     if (!node->input_edge(input_idx, &edge).ok()) {
307       if (listener->IsInterested()) {
308         *listener << "\ncould not find incoming edge for input " << input_idx;
309       }
310       return false;
311     }
312 
313     ::testing::StringMatchResultListener inner_listener;
314     OutEdge input = {edge->src(), edge->src_output()};
315     if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
316       return true;
317     }
318 
319     if (listener->IsInterested()) {
320       *listener << "\ninput " << input_idx << " does not match expected:\n";
321       (*input_matchers)[input_idx].DescribeTo(listener->stream());
322       string explanation = inner_listener.str();
323       if (!explanation.empty()) {
324         *listener << ", " << explanation;
325       }
326     }
327     return false;
328   }
329 
330   std::optional<string> op;
331   std::optional<string> name;
332   std::optional<string> assigned_device;
333   std::optional<Tensor> constant_value;
334   std::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers;
335   std::optional<::testing::Matcher<absl::Span<const Node* const>>>
336       control_dep_set;
337   std::map<string, std::optional<AttrValue>> attrs;
338 };
339 
340 // Matches a dst and dst_output on an input edge.  Today we only use this with
341 // dst_output=0 but we will eventually need to support multi-output operations.
342 class OutEdgeMatcher : public ::testing::MatcherInterface<OutEdge> {
343  public:
OutEdgeMatcher(::testing::Matcher<const Node * > src_matcher,int src_oidx)344   OutEdgeMatcher(::testing::Matcher<const Node*> src_matcher, int src_oidx)
345       : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {}
346 
MatchAndExplain(OutEdge out_edge,::testing::MatchResultListener * listener) const347   bool MatchAndExplain(
348       OutEdge out_edge,
349       ::testing::MatchResultListener* listener) const override {
350     ::testing::StringMatchResultListener inner_listener;
351     if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) {
352       if (listener->IsInterested()) {
353         *listener << "\nsource does not match expected ";
354         src_matcher_.DescribeTo(listener->stream());
355         string explanation = inner_listener.str();
356         if (!explanation.empty()) {
357           *listener << "\n\t" << explanation;
358         }
359       }
360       return false;
361     }
362     if (out_edge.second != src_oidx_) {
363       if (listener->IsInterested()) {
364         *listener << "\nexpected output slot to be " << src_oidx_
365                   << " but found " << out_edge.second;
366       }
367       return false;
368     }
369 
370     return true;
371   }
372 
DescribeTo(::std::ostream * os) const373   void DescribeTo(::std::ostream* os) const override {
374     if (src_oidx_) {
375       *os << "output slot: " << src_oidx_ << ", source: (";
376     }
377 
378     src_matcher_.DescribeTo(os);
379 
380     if (src_oidx_) {
381       *os << ")";
382     }
383   }
384 
385  private:
386   ::testing::Matcher<const Node*> src_matcher_;
387   int src_oidx_;
388 };
389 }  // namespace
390 
NodeWith(absl::Span<const NodeMatcherProperties> props)391 ::testing::Matcher<const Node*> impl::NodeWith(
392     absl::Span<const NodeMatcherProperties> props) {
393   NodeMatcher* matcher = new NodeMatcher();
394   for (const NodeMatcherProperties& prop : props) {
395     if (prop.name()) {
396       DCHECK(!matcher->name);
397       matcher->name = prop.name();
398     }
399 
400     if (prop.op()) {
401       DCHECK(!matcher->op);
402       matcher->op = prop.op();
403     }
404 
405     if (prop.constant_value()) {
406       DCHECK(!matcher->constant_value);
407       matcher->constant_value = prop.constant_value();
408     }
409 
410     if (prop.assigned_device()) {
411       DCHECK(!matcher->assigned_device);
412       matcher->assigned_device = prop.assigned_device();
413     }
414 
415     if (prop.inputs()) {
416       DCHECK(!matcher->input_matchers);
417       matcher->input_matchers = *prop.inputs();
418     }
419 
420     if (prop.control_deps()) {
421       DCHECK(!matcher->control_dep_set);
422       matcher->control_dep_set =
423           ::testing::UnorderedElementsAreArray(*prop.control_deps());
424     }
425 
426     if (prop.attr()) {
427       auto insert_result = matcher->attrs.insert(*prop.attr());
428       DCHECK(insert_result.second);
429     }
430   }
431 
432   return ::testing::MakeMatcher(matcher);
433 }
434 
Name(string name)435 impl::NodeMatcherProperties Name(string name) {
436   impl::NodeMatcherProperties props;
437   props.set_name(std::move(name));
438   return props;
439 }
440 
441 // Matches a node with op `op`.
Op(string op)442 impl::NodeMatcherProperties Op(string op) {
443   impl::NodeMatcherProperties props;
444   props.set_op(std::move(op));
445   return props;
446 }
447 
448 // Matches a node with assigned device `assigned_device`.
AssignedDevice(string assigned_device)449 impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
450   impl::NodeMatcherProperties props;
451   props.set_assigned_device(std::move(assigned_device));
452   return props;
453 }
454 
Inputs(absl::Span<const::testing::Matcher<OutEdge>> inputs)455 impl::NodeMatcherProperties impl::Inputs(
456     absl::Span<const ::testing::Matcher<OutEdge>> inputs) {
457   std::vector<::testing::Matcher<OutEdge>> inputs_vector;
458   absl::c_copy(inputs, std::back_inserter(inputs_vector));
459 
460   impl::NodeMatcherProperties props;
461   props.set_inputs(std::move(inputs_vector));
462   return props;
463 }
464 
CtrlDeps(absl::Span<const::testing::Matcher<const Node * >> control_deps)465 impl::NodeMatcherProperties impl::CtrlDeps(
466     absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
467   std::vector<::testing::Matcher<const Node*>> control_deps_vector;
468   absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
469 
470   impl::NodeMatcherProperties props;
471   props.set_control_deps(std::move(control_deps_vector));
472   return props;
473 }
474 
AttrLiteralHelper(const std::pair<string,bool> & bool_attr)475 std::pair<string, AttrValue> impl::AttrLiteralHelper(
476     const std::pair<string, bool>& bool_attr) {
477   AttrValue attr_value;
478   attr_value.set_b(bool_attr.second);
479   return {bool_attr.first, attr_value};
480 }
481 
AttrLiteralHelper(const std::pair<string,absl::Span<const int>> & int_list_attr)482 std::pair<string, AttrValue> impl::AttrLiteralHelper(
483     const std::pair<string, absl::Span<const int>>& int_list_attr) {
484   AttrValue attr_value;
485   AttrValue::ListValue* list = attr_value.mutable_list();
486   for (int i : int_list_attr.second) {
487     list->add_i(i);
488   }
489   return {int_list_attr.first, attr_value};
490 }
491 
AttrLiteralHelper(const std::pair<string,absl::Span<const string>> & string_list_attr)492 std::pair<string, AttrValue> impl::AttrLiteralHelper(
493     const std::pair<string, absl::Span<const string>>& string_list_attr) {
494   AttrValue attr_value;
495   AttrValue::ListValue* list = attr_value.mutable_list();
496   for (const string& s : string_list_attr.second) {
497     list->add_s(s);
498   }
499   return {string_list_attr.first, attr_value};
500 }
501 
Attr(std::pair<string,AttrValue> attr)502 impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
503   impl::NodeMatcherProperties props;
504   props.set_attr(std::move(attr));
505   return props;
506 }
507 
Attr(string name)508 impl::NodeMatcherProperties impl::Attr(string name) {
509   impl::NodeMatcherProperties props;
510   props.set_attr({std::move(name), std::nullopt});
511   return props;
512 }
513 
ConstantValue(const::tensorflow::Input::Initializer & val)514 NodeMatcherProperties ConstantValue(
515     const ::tensorflow::Input::Initializer& val) {
516   TF_CHECK_OK(val.status);
517   NodeMatcherProperties props;
518   props.set_constant_value(val.tensor);
519   return props;
520 }
521 
Const(const::tensorflow::Input::Initializer & val)522 ::testing::Matcher<impl::OutEdge> Const(
523     const ::tensorflow::Input::Initializer& val) {
524   return Out(NodeWith(ConstantValue(val)));
525 }
Out(int oidx,::testing::Matcher<const Node * > node_matcher)526 ::testing::Matcher<impl::OutEdge> Out(
527     int oidx, ::testing::Matcher<const Node*> node_matcher) {
528   return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx));
529 }
530 }  // namespace matchers
531 
FindNodeByName(Graph * g,absl::string_view name)532 Node* FindNodeByName(Graph* g, absl::string_view name) {
533   for (Node* n : g->nodes()) {
534     if (n->name() == name) {
535       return n;
536     }
537   }
538 
539   return nullptr;
540 }
541 }  // namespace testing
542 
PrintTo(const Node * n,::std::ostream * os)543 void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
PrintTo(Node * n,::std::ostream * os)544 void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
545 }  // namespace tensorflow
546