1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal_comparison.h"
17
18 #include <unistd.h>
19
20 #include <cmath>
21 #include <vector>
22
23 #include "absl/base/casts.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/platform/env.h"
29
30 using absl::StrAppend;
31 using absl::StrAppendFormat;
32 using absl::StrCat;
33
34 namespace xla {
35 namespace literal_comparison {
36 namespace {
37
// Returns the value whose bits are to be inspected when `val` takes part in a
// bitwise comparison. For the types served by this template that is simply
// `val` itself. A separate overload exists for Eigen::half, because Eigen::half
// doesn't satisfy the absl::bit_cast contract and its raw 16-bit payload has
// to be extracted explicitly.
template <typename T>
T GetRawValue(T val) {
  T raw = val;
  return raw;
}
GetRawValue(Eigen::half val)44 uint16_t GetRawValue(Eigen::half val) {
45 return Eigen::numext::bit_cast<uint16_t>(val);
46 }
47
48 // Helper function for comparing a floating point type, FloatT, bitwise equal
49 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
50 // -- on miscompare, a nice error message is given in the AssertionFailure.
51 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64_t> multi_index)52 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
53 absl::Span<const int64_t> multi_index) {
54 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
55 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
56 return ulhs == urhs;
57 }
58
59 // Templated comparator that specializes for float equality comparison with the
60 // bitwise helper above (this is the un-specialized fallback, to just use the
61 // default gunit implementation).
62 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64_t> multi_index)63 bool CompareEqual(NativeT lhs, NativeT rhs,
64 absl::Span<const int64_t> multi_index) {
65 return lhs == rhs;
66 }
67
68 // Specializations for floating types that do bitwise comparisons when equality
69 // comparison is requested.
70 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64_t> multi_index)71 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
72 absl::Span<const int64_t> multi_index) {
73 return CompareFloatsBitwiseEqual<bfloat16, uint16_t>(lhs, rhs, multi_index);
74 }
75 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64_t> multi_index)76 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
77 absl::Span<const int64_t> multi_index) {
78 return CompareFloatsBitwiseEqual<Eigen::half, uint16_t>(lhs, rhs,
79 multi_index);
80 }
81 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64_t> multi_index)82 bool CompareEqual<float>(float lhs, float rhs,
83 absl::Span<const int64_t> multi_index) {
84 return CompareFloatsBitwiseEqual<float, uint32_t>(lhs, rhs, multi_index);
85 }
86 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64_t> multi_index)87 bool CompareEqual<double>(double lhs, double rhs,
88 absl::Span<const int64_t> multi_index) {
89 return CompareFloatsBitwiseEqual<double, uint64_t>(lhs, rhs, multi_index);
90 }
91 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64_t> multi_index)92 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
93 absl::Span<const int64_t> multi_index) {
94 return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
95 CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
96 }
97 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64_t> multi_index)98 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
99 absl::Span<const int64_t> multi_index) {
100 return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
101 CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
102 }
103
104 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64_t> multi_index)105 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
106 absl::Span<const int64_t> multi_index) {
107 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
108 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
109 auto lhs_double = static_cast<double>(lhs);
110 auto rhs_double = static_cast<double>(rhs);
111 return InvalidArgument(
112 "floating values are not bitwise-equal; and equality testing "
113 "was requested: %s=%s=%a vs %s=%s=%a at array index %s",
114 StrCat(absl::Hex(ulhs)), RoundTripFpToString(lhs), lhs_double,
115 StrCat(absl::Hex(urhs)), RoundTripFpToString(rhs), rhs_double,
116 LiteralUtil::MultiIndexAsString(multi_index));
117 }
118
119 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64_t> multi_index)120 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
121 absl::Span<const int64_t> multi_index) {
122 return InvalidArgument(
123 "first mismatch at array index %s:\n expected value: %s\n actual "
124 "value: %s",
125 LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
126 }
127
128 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64_t> multi_index)129 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
130 absl::Span<const int64_t> multi_index) {
131 return MakeBitwiseErrorStatus<bfloat16, uint16_t>(lhs, rhs, multi_index);
132 }
133 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64_t> multi_index)134 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
135 absl::Span<const int64_t> multi_index) {
136 return MakeBitwiseErrorStatus<Eigen::half, uint16_t>(lhs, rhs, multi_index);
137 }
138 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64_t> multi_index)139 Status MakeErrorStatus(float lhs, float rhs,
140 absl::Span<const int64_t> multi_index) {
141 return MakeBitwiseErrorStatus<float, uint32_t>(lhs, rhs, multi_index);
142 }
143 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64_t> multi_index)144 Status MakeErrorStatus(double lhs, double rhs,
145 absl::Span<const int64_t> multi_index) {
146 return MakeBitwiseErrorStatus<double, uint64_t>(lhs, rhs, multi_index);
147 }
148 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64_t> multi_index)149 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
150 absl::Span<const int64_t> multi_index) {
151 if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
152 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
153 }
154 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
155 }
156 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64_t> multi_index)157 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
158 absl::Span<const int64_t> multi_index) {
159 if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
160 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
161 }
162 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
163 }
164
165 // A recursive function which iterates through every index of expected and
166 // actual literal and compares their values elementwise. Returns true if all
167 // elements are equal. Mismatched must either be:
168 // - a literal of booleans that has the same shape as expected and actual. In
169 // this case, each index in mismatched will be set to true if expected does
170 // not equal actual at that index and false if there are equal.
171 // - nullptr. In this case, the function will return once any mismatch is
172 // found between expected and actual.
173 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64_t> multi_index,int64_t dimension,Literal * mismatched=nullptr)174 Status Equal(LiteralSlice expected, LiteralSlice actual,
175 absl::Span<int64_t> multi_index, int64_t dimension,
176 Literal* mismatched = nullptr) {
177 if (dimension == expected.shape().dimensions_size()) {
178 NativeT expected_value = expected.Get<NativeT>(multi_index);
179 NativeT actual_value = actual.Get<NativeT>(multi_index);
180 bool result =
181 CompareEqual<NativeT>(expected_value, actual_value, multi_index);
182 if (mismatched) {
183 mismatched->Set<bool>(multi_index, !result);
184 }
185 return result ? OkStatus()
186 : MakeErrorStatus<NativeT>(expected_value, actual_value,
187 multi_index);
188 }
189
190 Status result;
191 for (int64_t i = 0; i < expected.shape().dimensions(dimension); ++i) {
192 multi_index[dimension] = i;
193 if (mismatched != nullptr) {
194 result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
195 mismatched));
196 } else {
197 TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
198 dimension + 1, mismatched));
199 }
200 }
201 return result;
202 }
203
204 // Gets the total element count. For tuples, this is not the count of tuple
205 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)206 int64_t RecursiveElementCount(const Shape& shape) {
207 if (shape.IsTuple()) {
208 const int64_t tuple_elements = ShapeUtil::TupleElementCount(shape);
209 int64_t total = 0;
210 for (int64_t i = 0; i < tuple_elements; ++i) {
211 total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
212 }
213 return total;
214 } else if (shape.IsArray()) {
215 return ShapeUtil::ElementsIn(shape);
216 } else {
217 return 0;
218 }
219 }
220
221 // Returns whether the given value is infinity.
222 template <typename NativeT>
IsInf(NativeT val)223 bool IsInf(NativeT val) {
224 return Eigen::numext::isinf(val);
225 }
226 // Returns whether the given value is nan.
227 template <typename NativeT>
IsNan(NativeT value)228 bool IsNan(NativeT value) {
229 return Eigen::numext::isnan(value);
230 }
231
232 // Converts the given floating-point value to a string.
FpValueToString(bfloat16 value)233 std::string FpValueToString(bfloat16 value) {
234 return absl::StrFormat("%10.4g", static_cast<double>(value));
235 }
236
FpValueToString(half value)237 std::string FpValueToString(half value) {
238 return absl::StrFormat("%11.5g", static_cast<double>(value));
239 }
240
FpValueToString(float value)241 std::string FpValueToString(float value) {
242 return absl::StrFormat("%15.9g", static_cast<double>(value));
243 }
244
FpValueToString(double value)245 std::string FpValueToString(double value) {
246 return absl::StrFormat("%24.17g", value);
247 }
248
FpValueToString(complex64 value)249 std::string FpValueToString(complex64 value) {
250 return absl::StrCat(FpValueToString(value.real()), " + ",
251 FpValueToString(value.imag()));
252 }
253
FpValueToString(complex128 value)254 std::string FpValueToString(complex128 value) {
255 return absl::StrCat(FpValueToString(value.real()), " + ",
256 FpValueToString(value.imag()));
257 }
258
// Thin wrapper around std::abs that always yields a double. It exists so that
// element types std::abs does not support -- in particular bfloat16 and half,
// which are handled by explicit specializations -- can share one call site
// with the types it does support.
template <typename NativeT>
double FpAbsoluteValue(NativeT value) {
  const double magnitude = std::abs(value);
  return magnitude;
}
265
266 template <>
FpAbsoluteValue(bfloat16 value)267 double FpAbsoluteValue(bfloat16 value) {
268 return FpAbsoluteValue<float>(static_cast<float>(value));
269 }
270
271 template <>
FpAbsoluteValue(half value)272 double FpAbsoluteValue(half value) {
273 return FpAbsoluteValue<float>(static_cast<float>(value));
274 }
275
276 // Helper class for comparing floating-point literals within an error bound.
277 template <typename NativeT>
278 class NearComparator {
279 public:
280 // Compares the two array literals elementwise and returns a comparison
281 // result. The comparison is ok() if all actual and expected elements are
282 // within the given error bound. In case of error, the status contains a
283 // detailed message about the discrepancy.
Compare(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)284 static Status Compare(const LiteralSlice& expected,
285 const LiteralSlice& actual,
286 const ShapeIndex& shape_index, ErrorSpec error,
287 bool detailed_message,
288 const MiscompareCallback& miscompare_callback) {
289 NearComparator<NativeT> comparator(expected, actual, shape_index, error,
290 detailed_message, miscompare_callback);
291 return comparator.Run();
292 }
293
294 private:
295 // Data structure encapsulating metadata about a single element mismatch.
296 struct Mismatch {
297 NativeT actual;
298 NativeT expected;
299 double rel_error;
300 double abs_error;
301
302 // The linear index of the failure within the shape. This linear index is
303 // from the 'actual' literal.
304 int64_t linear_index;
305
operator <xla::literal_comparison::__anonbb060ba40111::NearComparator::Mismatch306 bool operator<(const Mismatch& other) const {
307 return rel_error < other.rel_error;
308 }
309
ToStringxla::literal_comparison::__anonbb060ba40111::NearComparator::Mismatch310 std::string ToString(const Shape& shape) const {
311 return absl::StrFormat(
312 "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
313 FpValueToString(actual), FpValueToString(expected),
314 LiteralUtil::MultiIndexAsString(
315 IndexUtil::LinearIndexToMultidimensionalIndex(shape,
316 linear_index)),
317 rel_error, abs_error);
318 }
319 };
320
NearComparator(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)321 NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
322 const ShapeIndex& shape_index, ErrorSpec error,
323 bool detailed_message,
324 const MiscompareCallback& miscompare_callback)
325 : expected_(expected),
326 actual_(actual),
327 shape_index_(shape_index),
328 error_(error),
329 detailed_message_(detailed_message),
330 miscompare_callback_(miscompare_callback),
331 abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
332 abs_error_buckets_(kErrorBucketBounds.size(), 0),
333 rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
334
335 // Runs the comparison between expected and actual literals.
Run()336 Status Run() {
337 // If the shapes mismatch, we simply fail the expectation instead of
338 // printing out data, as it's a type error rather than a value error.
339 TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
340 if (!expected_.shape().IsArray()) {
341 return InvalidArgument("Expected array shape; got %s.",
342 ShapeUtil::HumanString(expected_.shape()));
343 }
344
345 mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
346 mismatches_.PopulateWithValue(false);
347
348 CompareLiterals();
349
350 if (num_mismatches_ == 0) {
351 return OkStatus();
352 } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
353 miscompare_callback_(
354 expected_, actual_, mismatches_, shape_index_,
355 ErrorBuckets(abs_error_buckets_, rel_error_buckets_));
356 }
357 return InvalidArgument("%s", ErrorMessage());
358 }
359
360 // Insert the given absolute value into the absolute value bucket vector. The
361 // bounds of the buckets are given by kAbsValueBucketBounds.
UpdateAbsValueBucket(NativeT value,bool is_mismatch)362 void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
363 // Adjust the bucket containing the absolute values of the 'actual'
364 // elements.
365 const double abs_value = FpAbsoluteValue(value);
366 for (int i = 0; i < abs_value_buckets_.size(); ++i) {
367 if (i == abs_value_buckets_.size() - 1 ||
368 (abs_value >= kAbsValueBucketBounds[i] &&
369 abs_value < kAbsValueBucketBounds[i + 1])) {
370 // The first value of the pair is the count of elements in the bucket,
371 // the second is the count of mismatches in the bucket.
372 abs_value_buckets_[i].first++;
373 if (is_mismatch) {
374 abs_value_buckets_[i].second++;
375 }
376 return;
377 }
378 }
379 }
380
381 // Insert the given error into the given error bucket vector.
UpdateErrorBucket(double error,absl::Span<int64_t> error_buckets)382 void UpdateErrorBucket(double error, absl::Span<int64_t> error_buckets) {
383 CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
384 for (int i = 0; i < error_buckets.size(); ++i) {
385 if (error >= kErrorBucketBounds[i]) {
386 error_buckets[i]++;
387 }
388 }
389 }
390
391 // Compares the two given elements from the expected and actual literals at
392 // the given literal_index and keeps track of various mismatch statistics.
393 template <typename T>
CompareValues(T expected,T actual,int64_t linear_index)394 void CompareValues(T expected, T actual, int64_t linear_index) {
395 double abs_error;
396 double rel_error;
397 if (CompareEqual<T>(expected, actual, {linear_index})) {
398 abs_error = 0;
399 rel_error = 0;
400 } else if (IsNan(expected) || IsNan(actual)) {
401 if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
402 (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
403 num_nan_mismatches_++;
404 // A nan mismatch is considered to have infinite error. rel_error is
405 // used for sorting a std::set of the top mismatches, and a nan value
406 // here will result in undefined behavior because nan's do not satisfy
407 // the strict weak ordering requirement of std containers.
408 abs_error = std::numeric_limits<float>::infinity();
409 rel_error = std::numeric_limits<float>::infinity();
410 } else {
411 abs_error = 0;
412 rel_error = 0;
413 }
414 } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
415 // `fewer_infs_ok` gives us the option of comparing as though `actual`
416 // were float_max/min rather than inf.
417 T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
418 : std::numeric_limits<T>::lowest();
419 abs_error = FpAbsoluteValue(actual_finite - expected);
420
421 // Avoid division by 0 even though it's well-defined because ubsan can be
422 // configured to treat this as a fatal error.
423 if (expected != T{0}) {
424 rel_error = abs_error / FpAbsoluteValue(expected);
425 } else {
426 rel_error = std::numeric_limits<float>::infinity();
427 }
428 } else if (IsInf(expected) || IsInf(actual)) {
429 // If either the expected or actual value is infinity but not both,
430 // then both absolute and relative error are regarded as infinity.
431 CHECK(!CompareEqual(expected, actual, {linear_index}));
432 abs_error = std::numeric_limits<float>::infinity();
433 rel_error = std::numeric_limits<float>::infinity();
434 } else {
435 abs_error = FpAbsoluteValue(actual - expected);
436
437 // Avoid division by 0 even though it's well-defined because ubsan can be
438 // configured to treat this as a fatal error.
439 if (expected != T{0}) {
440 rel_error = abs_error / FpAbsoluteValue(expected);
441 } else {
442 rel_error = std::numeric_limits<float>::infinity();
443 }
444 }
445 const bool is_abs_mismatch = abs_error > error_.abs;
446 const bool is_rel_mismatch = rel_error > error_.rel;
447 const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
448
449 // Update the error of the relative bucket only if the *absolute* error
450 // bound is exceeded and vice versa.
451 if (is_abs_mismatch) {
452 num_abs_mismatches_++;
453 UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
454 }
455 if (is_rel_mismatch) {
456 num_rel_mismatches_++;
457 UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
458 }
459
460 UpdateAbsValueBucket(actual, is_mismatch);
461
462 if (!is_mismatch) {
463 return;
464 }
465
466 num_mismatches_++;
467
468 // Keep track of the kTopRelativeErrorCount relative error mismatches.
469 if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
470 rel_error > top_rel_mismatches_.begin()->rel_error) {
471 Mismatch mismatch = {actual, expected, rel_error, abs_error,
472 linear_index};
473 top_rel_mismatches_.insert(mismatch);
474 if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
475 top_rel_mismatches_.erase(top_rel_mismatches_.begin());
476 }
477 }
478
479 mismatches_.data<bool>()[linear_index] = true;
480 }
481
482 // For complex types, we compare real and imaginary parts individually.
CompareValues(complex64 expected,complex64 actual,int64_t linear_index)483 void CompareValues(complex64 expected, complex64 actual,
484 int64_t linear_index) {
485 const auto both_parts_mismatch = num_mismatches_ + 2;
486 CompareValues<float>(expected.real(), actual.real(), linear_index);
487 CompareValues<float>(expected.imag(), actual.imag(), linear_index);
488 if (num_mismatches_ == both_parts_mismatch) {
489 // The mismatch counter had been incremented by each CompareValues() call,
490 // which means that both real and imaginary parts of the passed-in complex
491 // values are different. However, the counter should reflect a single
492 // mismatch between these complex values.
493 num_mismatches_--;
494 }
495 }
496
CompareValues(complex128 expected,complex128 actual,int64_t linear_index)497 void CompareValues(complex128 expected, complex128 actual,
498 int64_t linear_index) {
499 const auto both_parts_mismatch = num_mismatches_ + 2;
500 CompareValues<double>(expected.real(), actual.real(), linear_index);
501 CompareValues<double>(expected.imag(), actual.imag(), linear_index);
502 if (num_mismatches_ == both_parts_mismatch) {
503 // The mismatch counter had been incremented by each CompareValues() call,
504 // which means that both real and imaginary parts of the passed-in complex
505 // values are different. However, the counter should reflect a single
506 // mismatch between these complex values.
507 num_mismatches_--;
508 }
509 }
510
511 // Compares the two literals elementwise.
CompareLiterals()512 void CompareLiterals() {
513 // Fast path optimization for the case were layouts match.
514 if (LayoutUtil::Equal(actual_.shape().layout(),
515 expected_.shape().layout())) {
516 absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
517 absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
518 const int64_t len = expected_data.size();
519 for (int64_t i = 0; i < len; ++i) {
520 CompareValues(expected_data[i], actual_data[i], i);
521 }
522 return;
523 }
524 std::vector<int64_t> multi_index(actual_.shape().rank(), 0);
525 CompareLiteralsSlow(0, &multi_index);
526 }
527
528 // Slow path for CompareLiterals when 'actual' and 'expected' literals have
529 // different layouts. In this case, multidimensional indices are constructed
530 // and indexed for each element.
CompareLiteralsSlow(int64_t dimension,std::vector<int64_t> * multi_index)531 void CompareLiteralsSlow(int64_t dimension,
532 std::vector<int64_t>* multi_index) {
533 if (dimension == multi_index->size()) {
534 CompareValues(expected_.Get<NativeT>(*multi_index),
535 actual_.Get<NativeT>(*multi_index),
536 IndexUtil::MultidimensionalIndexToLinearIndex(
537 actual_.shape(), *multi_index));
538 } else {
539 for (int64_t i = 0; i < expected_.shape().dimensions(dimension); ++i) {
540 (*multi_index)[dimension] = i;
541 CompareLiteralsSlow(dimension + 1, multi_index);
542 }
543 }
544 }
545
546 // Returns an error message string with a detailed breakdown of the
547 // mismatches. Called after calling Run().
ErrorMessage()548 std::string ErrorMessage() {
549 std::string out;
550 int64_t element_count = ShapeUtil::ElementsIn(actual_.shape());
551
552 auto percent_string = [](float a, float b) {
553 float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
554 return absl::StrFormat("%0.4f%%", pct);
555 };
556
557 StrAppendFormat(
558 &out,
559 "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
560 "%g, rel bound %g\n",
561 num_mismatches_, percent_string(num_mismatches_, element_count),
562 ShapeUtil::HumanString(actual_.shape()),
563 ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
564 if (num_nan_mismatches_ > 0) {
565 StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
566 }
567 StrAppendFormat(&out, "Top relative error mismatches:\n");
568 for (auto it = top_rel_mismatches_.rbegin();
569 it != top_rel_mismatches_.rend(); ++it) {
570 StrAppend(&out, " ", it->ToString(actual_.shape()), "\n");
571 }
572
573 if (!detailed_message_) {
574 return out;
575 }
576
577 StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
578 CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
579 for (int i = 0; i < abs_value_buckets_.size(); ++i) {
580 const int64_t bucket_size = abs_value_buckets_[i].first;
581 const int64_t bucket_mismatches = abs_value_buckets_[i].second;
582 std::string mismatch_str =
583 bucket_mismatches > 0
584 ? absl::StrFormat(", mismatches %d", bucket_mismatches)
585 : "";
586 StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n",
587 kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
588 bucket_size, percent_string(bucket_size, element_count),
589 mismatch_str);
590 }
591
592 auto print_accum_buckets = [&](const std::string& header, int64_t total,
593 absl::Span<const int64_t> buckets) {
594 StrAppend(&out, header, ":\n");
595 StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
596 total - buckets[0],
597 percent_string(total - buckets[0], total));
598 CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
599 for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
600 StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
601 buckets[i], percent_string(buckets[i], total));
602 }
603 };
604 StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
605 error_.abs, num_abs_mismatches_,
606 percent_string(num_abs_mismatches_, element_count));
607 print_accum_buckets(
608 "Relative error breakdown of elements exceeding abs error bound",
609 num_abs_mismatches_, rel_error_buckets_);
610 StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
611 error_.rel, num_rel_mismatches_,
612 percent_string(num_rel_mismatches_, element_count));
613 print_accum_buckets(
614 "Absolute error breakdown of elements exceeding rel error bound",
615 num_rel_mismatches_, abs_error_buckets_);
616 return out;
617 }
618
619 // 'actual' and 'expected' literals being compared.
620 LiteralSlice expected_;
621 LiteralSlice actual_;
622
623 // The shape index of the LiteralSlice that is being compared.
624 ShapeIndex shape_index_;
625
626 // The error bounds of the comparison.
627 ErrorSpec error_;
628
629 // Whether to include detailed breakdown of mismatches in the error message.
630 bool detailed_message_;
631
632 // Callback to invoke on miscompare.
633 MiscompareCallback miscompare_callback_;
634
635 // Number of element mismatches encountered so far.
636 int64_t num_mismatches_ = 0;
637
638 // Number of elements with a nan mismatch.
639 int64_t num_nan_mismatches_ = 0;
640
641 // Number of elements which exceed the absolute/relative error bound.
642 int64_t num_abs_mismatches_ = 0;
643 int64_t num_rel_mismatches_ = 0;
644
645 // A Literal containing which elements did not match in the expected and
646 // actual literals. mismatches_ contains PREDs and is of the same sizes as
647 // the comparison literals.
648 Literal mismatches_;
649
650 // The number of mismatches to report in the output, sorted by relative error
651 // magnitude.
652 static constexpr int64_t kTopRelativeErrorCount = 5;
653
654 // The set of mismatches with the largest relative error. The size of this set
655 // is bounded by kTopRelativeErrorCount.
656 std::multiset<Mismatch> top_rel_mismatches_;
657
658 // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
659 // bounds of these buckets. abs_value_buckets_ contains a pair for each
660 // bucket: the element count and failure count.
661 static constexpr std::array<float, 7> kAbsValueBucketBounds = {
662 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
663 std::vector<std::pair<int64_t, int64_t>> abs_value_buckets_;
664
665 // Buckets for relative and absolute errors. The relative error buckets only
666 // contains those elements which exceed the *absolute* error bound, and vice
667 // versa. This makes it easy to see the effect of adjusting the relative (or
668 // absolute) error bound on the success of the comparison. kErrorBucketBounds
669 // are the lower bounds of the buckets in both vectors. The error buckets are
670 // a cumulative distribution so an error value may appear in more than one
671 // bucket. For example an error value of 0.003 may appear in the buckets
672 // bounded by 0.01, 0.1, and 1.0.
673 static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
674 0.01, 0.1, 1};
675 std::vector<int64_t> abs_error_buckets_;
676 std::vector<int64_t> rel_error_buckets_;
677 };
678
679 template <typename NativeT>
680 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
681 template <typename NativeT>
682 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
683
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const MiscompareCallback & miscompare_callback)684 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
685 const ShapeIndex& shape_index,
686 const MiscompareCallback& miscompare_callback) {
687 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
688
689 Status result;
690 if (expected.shape().IsTuple()) {
691 ShapeIndex next_index = shape_index;
692 for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
693 next_index.push_back(i);
694 Status tuple_result =
695 EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
696 next_index, miscompare_callback);
697 if (miscompare_callback) {
698 result.Update(tuple_result);
699 } else {
700 TF_RETURN_IF_ERROR(tuple_result);
701 }
702 next_index.pop_back();
703 }
704 } else {
705 std::vector<int64_t> multi_index(expected.shape().dimensions_size(), 0);
706 auto index = absl::MakeSpan(multi_index);
707
708 Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
709 expected.shape().dimensions());
710 Literal miscompared(unequal_shape);
711 Literal* miscompared_ptr =
712 (miscompare_callback == nullptr ? nullptr : &miscompared);
713
714 switch (expected.shape().element_type()) {
715 case PRED:
716 result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
717 break;
718 case S8:
719 result = Equal<int8_t>(expected, actual, index, 0, miscompared_ptr);
720 break;
721 case S16:
722 result = Equal<int16_t>(expected, actual, index, 0, miscompared_ptr);
723 break;
724 case S32:
725 result = Equal<int32_t>(expected, actual, index, 0, miscompared_ptr);
726 break;
727 case S64:
728 result = Equal<int64_t>(expected, actual, index, 0, miscompared_ptr);
729 break;
730 case U8:
731 result = Equal<uint8_t>(expected, actual, index, 0, miscompared_ptr);
732 break;
733 case U16:
734 result = Equal<uint16_t>(expected, actual, index, 0, miscompared_ptr);
735 break;
736 case U32:
737 result = Equal<uint32_t>(expected, actual, index, 0, miscompared_ptr);
738 break;
739 case U64:
740 result = Equal<uint64_t>(expected, actual, index, 0, miscompared_ptr);
741 break;
742 case BF16:
743 result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
744 break;
745 case F16:
746 result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
747 break;
748 case F32:
749 result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
750 break;
751 case F64:
752 result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
753 break;
754 case C64:
755 result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
756 break;
757 case C128:
758 result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
759 break;
760 case TOKEN:
761 // Tokens have no on-device representation and are trivially equal.
762 return OkStatus();
763 default:
764 LOG(FATAL) << "Unsupported primitive type: "
765 << PrimitiveType_Name(expected.shape().element_type());
766 }
767
768 if (!result.ok() && miscompare_callback) {
769 miscompare_callback(expected, actual, LiteralSlice(miscompared),
770 shape_index, ErrorBuckets());
771 }
772 }
773
774 return result;
775 }
776
777 // Helper function for comparing two literals for nearness. Handles tuple-shapes
778 // via recursion. shape_index is the ShapeIndex of expected (or actual)
779 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const ErrorSpec & error,std::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)780 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
781 const ShapeIndex& shape_index, const ErrorSpec& error,
782 std::optional<bool> detailed_message,
783 const MiscompareCallback& miscompare_callback) {
784 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
785
786 if (expected.shape().IsTuple()) {
787 Status return_status;
788 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(expected.shape());
789 ++i) {
790 const auto expected_element = LiteralSlice(expected, {i});
791 const auto actual_element = LiteralSlice(actual, {i});
792 ShapeIndex element_index = shape_index;
793 element_index.push_back(i);
794 Status element_result =
795 NearHelper(expected_element, actual_element, element_index, error,
796 detailed_message, miscompare_callback);
797 if (!element_result.ok()) {
798 element_result = InvalidArgument("Array at shape index %s, %s",
799 element_index.ToString(),
800 element_result.error_message());
801 if (return_status.ok()) {
802 return_status = element_result;
803 } else {
804 return_status =
805 AppendStatus(return_status, element_result.error_message());
806 }
807 }
808 }
809 if (!return_status.ok() && shape_index.empty()) {
810 // Emit a top-level error message containing the top-level shape in case
811 // of mismatch.
812 int64_t total_elements = RecursiveElementCount(actual.shape());
813 return_status =
814 InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
815 ShapeUtil::HumanString(actual.shape()),
816 total_elements, return_status.error_message());
817 }
818 return return_status;
819 }
820
821 if (ShapeUtil::ElementIsFloating(expected.shape()) ||
822 ShapeUtil::ElementIsComplex(expected.shape())) {
823 bool use_detailed_message = detailed_message.value_or(
824 ShapeUtil::ElementsIn(expected.shape()) >= 64);
825 switch (expected.shape().element_type()) {
826 case BF16:
827 return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
828 error, use_detailed_message,
829 miscompare_callback);
830 break;
831 case F16:
832 return NearComparator<half>::Compare(expected, actual, shape_index,
833 error, use_detailed_message,
834 miscompare_callback);
835 break;
836 case F32:
837 return NearComparator<float>::Compare(expected, actual, shape_index,
838 error, use_detailed_message,
839 miscompare_callback);
840 break;
841 case F64:
842 return NearComparator<double>::Compare(expected, actual, shape_index,
843 error, use_detailed_message,
844 miscompare_callback);
845 break;
846 case C64:
847 return NearComparator<complex64>::Compare(expected, actual, shape_index,
848 error, use_detailed_message,
849 miscompare_callback);
850 break;
851 case C128:
852 return NearComparator<complex128>::Compare(
853 expected, actual, shape_index, error, use_detailed_message,
854 miscompare_callback);
855 break;
856 default:
857 LOG(FATAL) << "Unsupported primitive type in near comparator: "
858 << PrimitiveType_Name(expected.shape().element_type())
859 << ". Must be floating-point type.";
860 }
861 }
862
863 // Non-floating point, non-tuple literal.
864 return EqualHelper(expected, actual, shape_index, miscompare_callback);
865 }
866
867 } // namespace
868
EqualShapes(const Shape & expected,const Shape & actual)869 Status EqualShapes(const Shape& expected, const Shape& actual) {
870 if (expected.element_type() != actual.element_type()) {
871 return InvalidArgument("element type mismatch, want: %s got %s",
872 ShapeUtil::HumanString(expected),
873 ShapeUtil::HumanString(actual));
874 }
875 if (expected.IsTuple()) {
876 if (ShapeUtil::TupleElementCount(expected) !=
877 ShapeUtil::TupleElementCount(actual)) {
878 return InvalidArgument(
879 "want tuple element count: %d got tuple element count: %d",
880 ShapeUtil::TupleElementCount(expected),
881 ShapeUtil::TupleElementCount(actual));
882 }
883 for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
884 Status result =
885 EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
886 if (!result.ok()) {
887 return AppendStatus(result, StrCat("mismatch in tuple index", i));
888 }
889 }
890 } else if (expected.IsArray()) {
891 if (expected.rank() != actual.rank()) {
892 return InvalidArgument("want rank of %s got rank of %s",
893 ShapeUtil::HumanString(expected),
894 ShapeUtil::HumanString(actual));
895 }
896 if (expected.element_type() != actual.element_type()) {
897 return InvalidArgument("mismatch in primitive type %s vs %s",
898 PrimitiveType_Name(expected.element_type()),
899 PrimitiveType_Name(actual.element_type()));
900 }
901 if (expected.dimensions_size() != actual.dimensions_size()) {
902 return InvalidArgument("want dimensions_size %d got dimensions_size %d",
903 expected.dimensions_size(),
904 actual.dimensions_size());
905 }
906 for (int i = 0; i < expected.dimensions_size(); ++i) {
907 if (expected.dimensions(i) != actual.dimensions(i)) {
908 return InvalidArgument(
909 "mismatch in dimension #%d expected: %s actual: %s", i,
910 ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
911 }
912 }
913 }
914 // Non-array, non-tuple shapes are trivially equivalent.
915 return OkStatus();
916 }
917
918 namespace {
919
920 // If result is an error, extend the error message with the expected and actual
921 // literals.
EmitLiteralsInErrorMessage(const Status & result,const LiteralSlice & expected,const LiteralSlice & actual)922 Status EmitLiteralsInErrorMessage(const Status& result,
923 const LiteralSlice& expected,
924 const LiteralSlice& actual) {
925 if (result.ok()) {
926 return result;
927 }
928 return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
929 result.error_message(), ToStringTruncated(expected),
930 ToStringTruncated(actual));
931 }
932
933 } // namespace
934
Equal(const LiteralSlice & expected,const LiteralSlice & actual)935 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
936 VLOG(1) << "expected:";
937 XLA_VLOG_LINES(1, expected.ToString());
938 VLOG(1) << "actual:";
939 XLA_VLOG_LINES(1, actual.ToString());
940 Status result = EqualHelper(expected, actual, {}, nullptr);
941 return EmitLiteralsInErrorMessage(result, expected, actual);
942 }
943
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,std::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)944 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
945 const ErrorSpec& error, std::optional<bool> detailed_message,
946 const MiscompareCallback& miscompare_callback) {
947 VLOG(1) << "Expected literal:";
948 XLA_VLOG_LINES(1, expected.ToString());
949 VLOG(1) << "Actual literal:";
950 XLA_VLOG_LINES(1, actual.ToString());
951 Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
952 detailed_message, miscompare_callback);
953 return EmitLiteralsInErrorMessage(result, expected, actual);
954 }
955
ToStringTruncated(const LiteralSlice & literal)956 std::string ToStringTruncated(const LiteralSlice& literal) {
957 return RecursiveElementCount(literal.shape()) < 1000
958 ? literal.ToString()
959 : "[TRUNCATED, Literal with more than 1000 values]";
960 }
961
962 } // namespace literal_comparison
963 } // namespace xla
964