// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/literal_comparison.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal_comparison.h"
17 
18 #include <unistd.h>
19 
20 #include <cmath>
21 #include <vector>
22 
23 #include "absl/base/casts.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/platform/env.h"
29 
30 using absl::StrAppend;
31 using absl::StrAppendFormat;
32 using absl::StrCat;
33 
34 namespace xla {
35 namespace literal_comparison {
36 namespace {
37 
// Since Eigen::half doesn't satisfy the absl::bit_cast contract, callers go
// through GetRawValue to obtain something that can be bit-cast. This generic
// overload simply returns `val` unchanged; the Eigen::half overload returns
// the underlying 16-bit pattern instead.
template <typename T>
T GetRawValue(T val) {
  return val;
}
GetRawValue(Eigen::half val)44 uint16_t GetRawValue(Eigen::half val) {
45   return Eigen::numext::bit_cast<uint16_t>(val);
46 }
47 
48 // Helper function for comparing a floating point type, FloatT, bitwise equal
49 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
50 // -- on miscompare, a nice error message is given in the AssertionFailure.
51 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64_t> multi_index)52 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
53                                absl::Span<const int64_t> multi_index) {
54   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
55   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
56   return ulhs == urhs;
57 }
58 
59 // Templated comparator that specializes for float equality comparison with the
60 // bitwise helper above (this is the un-specialized fallback, to just use the
61 // default gunit implementation).
62 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64_t> multi_index)63 bool CompareEqual(NativeT lhs, NativeT rhs,
64                   absl::Span<const int64_t> multi_index) {
65   return lhs == rhs;
66 }
67 
68 // Specializations for floating types that do bitwise comparisons when equality
69 // comparison is requested.
70 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64_t> multi_index)71 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
72                             absl::Span<const int64_t> multi_index) {
73   return CompareFloatsBitwiseEqual<bfloat16, uint16_t>(lhs, rhs, multi_index);
74 }
75 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64_t> multi_index)76 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
77                                absl::Span<const int64_t> multi_index) {
78   return CompareFloatsBitwiseEqual<Eigen::half, uint16_t>(lhs, rhs,
79                                                           multi_index);
80 }
81 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64_t> multi_index)82 bool CompareEqual<float>(float lhs, float rhs,
83                          absl::Span<const int64_t> multi_index) {
84   return CompareFloatsBitwiseEqual<float, uint32_t>(lhs, rhs, multi_index);
85 }
86 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64_t> multi_index)87 bool CompareEqual<double>(double lhs, double rhs,
88                           absl::Span<const int64_t> multi_index) {
89   return CompareFloatsBitwiseEqual<double, uint64_t>(lhs, rhs, multi_index);
90 }
91 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64_t> multi_index)92 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
93                              absl::Span<const int64_t> multi_index) {
94   return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
95          CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
96 }
97 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64_t> multi_index)98 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
99                               absl::Span<const int64_t> multi_index) {
100   return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
101          CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
102 }
103 
104 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64_t> multi_index)105 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
106                               absl::Span<const int64_t> multi_index) {
107   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
108   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
109   auto lhs_double = static_cast<double>(lhs);
110   auto rhs_double = static_cast<double>(rhs);
111   return InvalidArgument(
112       "floating values are not bitwise-equal; and equality testing "
113       "was requested: %s=%s=%a vs %s=%s=%a at array index %s",
114       StrCat(absl::Hex(ulhs)), RoundTripFpToString(lhs), lhs_double,
115       StrCat(absl::Hex(urhs)), RoundTripFpToString(rhs), rhs_double,
116       LiteralUtil::MultiIndexAsString(multi_index));
117 }
118 
119 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64_t> multi_index)120 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
121                        absl::Span<const int64_t> multi_index) {
122   return InvalidArgument(
123       "first mismatch at array index %s:\n  expected value: %s\n  actual "
124       "value:   %s",
125       LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
126 }
127 
128 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64_t> multi_index)129 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
130                        absl::Span<const int64_t> multi_index) {
131   return MakeBitwiseErrorStatus<bfloat16, uint16_t>(lhs, rhs, multi_index);
132 }
133 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64_t> multi_index)134 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
135                        absl::Span<const int64_t> multi_index) {
136   return MakeBitwiseErrorStatus<Eigen::half, uint16_t>(lhs, rhs, multi_index);
137 }
138 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64_t> multi_index)139 Status MakeErrorStatus(float lhs, float rhs,
140                        absl::Span<const int64_t> multi_index) {
141   return MakeBitwiseErrorStatus<float, uint32_t>(lhs, rhs, multi_index);
142 }
143 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64_t> multi_index)144 Status MakeErrorStatus(double lhs, double rhs,
145                        absl::Span<const int64_t> multi_index) {
146   return MakeBitwiseErrorStatus<double, uint64_t>(lhs, rhs, multi_index);
147 }
148 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64_t> multi_index)149 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
150                        absl::Span<const int64_t> multi_index) {
151   if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
152     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
153   }
154   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
155 }
156 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64_t> multi_index)157 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
158                        absl::Span<const int64_t> multi_index) {
159   if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
160     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
161   }
162   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
163 }
164 
165 // A recursive function which iterates through every index of expected and
166 // actual literal and compares their values elementwise. Returns true if all
167 // elements are equal. Mismatched must either be:
168 //    - a literal of booleans that has the same shape as expected and actual. In
169 //      this case, each index in mismatched will be set to true if expected does
170 //      not equal actual at that index and false if there are equal.
171 //    - nullptr. In this case, the function will return once any mismatch is
172 //      found between expected and actual.
173 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64_t> multi_index,int64_t dimension,Literal * mismatched=nullptr)174 Status Equal(LiteralSlice expected, LiteralSlice actual,
175              absl::Span<int64_t> multi_index, int64_t dimension,
176              Literal* mismatched = nullptr) {
177   if (dimension == expected.shape().dimensions_size()) {
178     NativeT expected_value = expected.Get<NativeT>(multi_index);
179     NativeT actual_value = actual.Get<NativeT>(multi_index);
180     bool result =
181         CompareEqual<NativeT>(expected_value, actual_value, multi_index);
182     if (mismatched) {
183       mismatched->Set<bool>(multi_index, !result);
184     }
185     return result ? OkStatus()
186                   : MakeErrorStatus<NativeT>(expected_value, actual_value,
187                                              multi_index);
188   }
189 
190   Status result;
191   for (int64_t i = 0; i < expected.shape().dimensions(dimension); ++i) {
192     multi_index[dimension] = i;
193     if (mismatched != nullptr) {
194       result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
195                                    mismatched));
196     } else {
197       TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
198                                         dimension + 1, mismatched));
199     }
200   }
201   return result;
202 }
203 
204 // Gets the total element count.  For tuples, this is not the count of tuple
205 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)206 int64_t RecursiveElementCount(const Shape& shape) {
207   if (shape.IsTuple()) {
208     const int64_t tuple_elements = ShapeUtil::TupleElementCount(shape);
209     int64_t total = 0;
210     for (int64_t i = 0; i < tuple_elements; ++i) {
211       total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
212     }
213     return total;
214   } else if (shape.IsArray()) {
215     return ShapeUtil::ElementsIn(shape);
216   } else {
217     return 0;
218   }
219 }
220 
221 // Returns whether the given value is infinity.
222 template <typename NativeT>
IsInf(NativeT val)223 bool IsInf(NativeT val) {
224   return Eigen::numext::isinf(val);
225 }
226 // Returns whether the given value is nan.
227 template <typename NativeT>
IsNan(NativeT value)228 bool IsNan(NativeT value) {
229   return Eigen::numext::isnan(value);
230 }
231 
232 // Converts the given floating-point value to a string.
FpValueToString(bfloat16 value)233 std::string FpValueToString(bfloat16 value) {
234   return absl::StrFormat("%10.4g", static_cast<double>(value));
235 }
236 
FpValueToString(half value)237 std::string FpValueToString(half value) {
238   return absl::StrFormat("%11.5g", static_cast<double>(value));
239 }
240 
FpValueToString(float value)241 std::string FpValueToString(float value) {
242   return absl::StrFormat("%15.9g", static_cast<double>(value));
243 }
244 
FpValueToString(double value)245 std::string FpValueToString(double value) {
246   return absl::StrFormat("%24.17g", value);
247 }
248 
FpValueToString(complex64 value)249 std::string FpValueToString(complex64 value) {
250   return absl::StrCat(FpValueToString(value.real()), " + ",
251                       FpValueToString(value.imag()));
252 }
253 
FpValueToString(complex128 value)254 std::string FpValueToString(complex128 value) {
255   return absl::StrCat(FpValueToString(value.real()), " + ",
256                       FpValueToString(value.imag()));
257 }
258 
// A wrapper of std::abs that returns the result as a double. The template is
// specialized for data types that are not supported by std::abs, in
// particular, bfloat16 and half.
template <typename NativeT>
double FpAbsoluteValue(NativeT value) {
  return std::abs(value);
}
265 
266 template <>
FpAbsoluteValue(bfloat16 value)267 double FpAbsoluteValue(bfloat16 value) {
268   return FpAbsoluteValue<float>(static_cast<float>(value));
269 }
270 
271 template <>
FpAbsoluteValue(half value)272 double FpAbsoluteValue(half value) {
273   return FpAbsoluteValue<float>(static_cast<float>(value));
274 }
275 
276 // Helper class for comparing floating-point literals within an error bound.
277 template <typename NativeT>
278 class NearComparator {
279  public:
280   // Compares the two array literals elementwise and returns a comparison
281   // result. The comparison is ok() if all actual and expected elements are
282   // within the given error bound. In case of error, the status contains a
283   // detailed message about the discrepancy.
Compare(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)284   static Status Compare(const LiteralSlice& expected,
285                         const LiteralSlice& actual,
286                         const ShapeIndex& shape_index, ErrorSpec error,
287                         bool detailed_message,
288                         const MiscompareCallback& miscompare_callback) {
289     NearComparator<NativeT> comparator(expected, actual, shape_index, error,
290                                        detailed_message, miscompare_callback);
291     return comparator.Run();
292   }
293 
294  private:
295   // Data structure encapsulating metadata about a single element mismatch.
296   struct Mismatch {
297     NativeT actual;
298     NativeT expected;
299     double rel_error;
300     double abs_error;
301 
302     // The linear index of the failure within the shape. This linear index is
303     // from the 'actual' literal.
304     int64_t linear_index;
305 
operator <xla::literal_comparison::__anonbb060ba40111::NearComparator::Mismatch306     bool operator<(const Mismatch& other) const {
307       return rel_error < other.rel_error;
308     }
309 
ToStringxla::literal_comparison::__anonbb060ba40111::NearComparator::Mismatch310     std::string ToString(const Shape& shape) const {
311       return absl::StrFormat(
312           "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
313           FpValueToString(actual), FpValueToString(expected),
314           LiteralUtil::MultiIndexAsString(
315               IndexUtil::LinearIndexToMultidimensionalIndex(shape,
316                                                             linear_index)),
317           rel_error, abs_error);
318     }
319   };
320 
NearComparator(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)321   NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
322                  const ShapeIndex& shape_index, ErrorSpec error,
323                  bool detailed_message,
324                  const MiscompareCallback& miscompare_callback)
325       : expected_(expected),
326         actual_(actual),
327         shape_index_(shape_index),
328         error_(error),
329         detailed_message_(detailed_message),
330         miscompare_callback_(miscompare_callback),
331         abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
332         abs_error_buckets_(kErrorBucketBounds.size(), 0),
333         rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
334 
335   // Runs the comparison between expected and actual literals.
Run()336   Status Run() {
337     // If the shapes mismatch, we simply fail the expectation instead of
338     // printing out data, as it's a type error rather than a value error.
339     TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
340     if (!expected_.shape().IsArray()) {
341       return InvalidArgument("Expected array shape; got %s.",
342                              ShapeUtil::HumanString(expected_.shape()));
343     }
344 
345     mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
346     mismatches_.PopulateWithValue(false);
347 
348     CompareLiterals();
349 
350     if (num_mismatches_ == 0) {
351       return OkStatus();
352     } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
353       miscompare_callback_(
354           expected_, actual_, mismatches_, shape_index_,
355           ErrorBuckets(abs_error_buckets_, rel_error_buckets_));
356     }
357     return InvalidArgument("%s", ErrorMessage());
358   }
359 
360   // Insert the given absolute value into the absolute value bucket vector. The
361   // bounds of the buckets are given by kAbsValueBucketBounds.
UpdateAbsValueBucket(NativeT value,bool is_mismatch)362   void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
363     // Adjust the bucket containing the absolute values of the 'actual'
364     // elements.
365     const double abs_value = FpAbsoluteValue(value);
366     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
367       if (i == abs_value_buckets_.size() - 1 ||
368           (abs_value >= kAbsValueBucketBounds[i] &&
369            abs_value < kAbsValueBucketBounds[i + 1])) {
370         // The first value of the pair is the count of elements in the bucket,
371         // the second is the count of mismatches in the bucket.
372         abs_value_buckets_[i].first++;
373         if (is_mismatch) {
374           abs_value_buckets_[i].second++;
375         }
376         return;
377       }
378     }
379   }
380 
381   // Insert the given error into the given error bucket vector.
UpdateErrorBucket(double error,absl::Span<int64_t> error_buckets)382   void UpdateErrorBucket(double error, absl::Span<int64_t> error_buckets) {
383     CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
384     for (int i = 0; i < error_buckets.size(); ++i) {
385       if (error >= kErrorBucketBounds[i]) {
386         error_buckets[i]++;
387       }
388     }
389   }
390 
391   // Compares the two given elements from the expected and actual literals at
392   // the given literal_index and keeps track of various mismatch statistics.
393   template <typename T>
CompareValues(T expected,T actual,int64_t linear_index)394   void CompareValues(T expected, T actual, int64_t linear_index) {
395     double abs_error;
396     double rel_error;
397     if (CompareEqual<T>(expected, actual, {linear_index})) {
398       abs_error = 0;
399       rel_error = 0;
400     } else if (IsNan(expected) || IsNan(actual)) {
401       if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
402           (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
403         num_nan_mismatches_++;
404         // A nan mismatch is considered to have infinite error. rel_error is
405         // used for sorting a std::set of the top mismatches, and a nan value
406         // here will result in undefined behavior because nan's do not satisfy
407         // the strict weak ordering requirement of std containers.
408         abs_error = std::numeric_limits<float>::infinity();
409         rel_error = std::numeric_limits<float>::infinity();
410       } else {
411         abs_error = 0;
412         rel_error = 0;
413       }
414     } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
415       // `fewer_infs_ok` gives us the option of comparing as though `actual`
416       // were float_max/min rather than inf.
417       T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
418                                       : std::numeric_limits<T>::lowest();
419       abs_error = FpAbsoluteValue(actual_finite - expected);
420 
421       // Avoid division by 0 even though it's well-defined because ubsan can be
422       // configured to treat this as a fatal error.
423       if (expected != T{0}) {
424         rel_error = abs_error / FpAbsoluteValue(expected);
425       } else {
426         rel_error = std::numeric_limits<float>::infinity();
427       }
428     } else if (IsInf(expected) || IsInf(actual)) {
429       // If either the expected or actual value is infinity but not both,
430       // then both absolute and relative error are regarded as infinity.
431       CHECK(!CompareEqual(expected, actual, {linear_index}));
432       abs_error = std::numeric_limits<float>::infinity();
433       rel_error = std::numeric_limits<float>::infinity();
434     } else {
435       abs_error = FpAbsoluteValue(actual - expected);
436 
437       // Avoid division by 0 even though it's well-defined because ubsan can be
438       // configured to treat this as a fatal error.
439       if (expected != T{0}) {
440         rel_error = abs_error / FpAbsoluteValue(expected);
441       } else {
442         rel_error = std::numeric_limits<float>::infinity();
443       }
444     }
445     const bool is_abs_mismatch = abs_error > error_.abs;
446     const bool is_rel_mismatch = rel_error > error_.rel;
447     const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
448 
449     // Update the error of the relative bucket only if the *absolute* error
450     // bound is exceeded and vice versa.
451     if (is_abs_mismatch) {
452       num_abs_mismatches_++;
453       UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
454     }
455     if (is_rel_mismatch) {
456       num_rel_mismatches_++;
457       UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
458     }
459 
460     UpdateAbsValueBucket(actual, is_mismatch);
461 
462     if (!is_mismatch) {
463       return;
464     }
465 
466     num_mismatches_++;
467 
468     // Keep track of the kTopRelativeErrorCount relative error mismatches.
469     if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
470         rel_error > top_rel_mismatches_.begin()->rel_error) {
471       Mismatch mismatch = {actual, expected, rel_error, abs_error,
472                            linear_index};
473       top_rel_mismatches_.insert(mismatch);
474       if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
475         top_rel_mismatches_.erase(top_rel_mismatches_.begin());
476       }
477     }
478 
479     mismatches_.data<bool>()[linear_index] = true;
480   }
481 
482   // For complex types, we compare real and imaginary parts individually.
CompareValues(complex64 expected,complex64 actual,int64_t linear_index)483   void CompareValues(complex64 expected, complex64 actual,
484                      int64_t linear_index) {
485     const auto both_parts_mismatch = num_mismatches_ + 2;
486     CompareValues<float>(expected.real(), actual.real(), linear_index);
487     CompareValues<float>(expected.imag(), actual.imag(), linear_index);
488     if (num_mismatches_ == both_parts_mismatch) {
489       // The mismatch counter had been incremented by each CompareValues() call,
490       // which means that both real and imaginary parts of the passed-in complex
491       // values are different. However, the counter should reflect a single
492       // mismatch between these complex values.
493       num_mismatches_--;
494     }
495   }
496 
CompareValues(complex128 expected,complex128 actual,int64_t linear_index)497   void CompareValues(complex128 expected, complex128 actual,
498                      int64_t linear_index) {
499     const auto both_parts_mismatch = num_mismatches_ + 2;
500     CompareValues<double>(expected.real(), actual.real(), linear_index);
501     CompareValues<double>(expected.imag(), actual.imag(), linear_index);
502     if (num_mismatches_ == both_parts_mismatch) {
503       // The mismatch counter had been incremented by each CompareValues() call,
504       // which means that both real and imaginary parts of the passed-in complex
505       // values are different. However, the counter should reflect a single
506       // mismatch between these complex values.
507       num_mismatches_--;
508     }
509   }
510 
511   // Compares the two literals elementwise.
CompareLiterals()512   void CompareLiterals() {
513     // Fast path optimization for the case were layouts match.
514     if (LayoutUtil::Equal(actual_.shape().layout(),
515                           expected_.shape().layout())) {
516       absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
517       absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
518       const int64_t len = expected_data.size();
519       for (int64_t i = 0; i < len; ++i) {
520         CompareValues(expected_data[i], actual_data[i], i);
521       }
522       return;
523     }
524     std::vector<int64_t> multi_index(actual_.shape().rank(), 0);
525     CompareLiteralsSlow(0, &multi_index);
526   }
527 
528   // Slow path for CompareLiterals when 'actual' and 'expected' literals have
529   // different layouts. In this case, multidimensional indices are constructed
530   // and indexed for each element.
CompareLiteralsSlow(int64_t dimension,std::vector<int64_t> * multi_index)531   void CompareLiteralsSlow(int64_t dimension,
532                            std::vector<int64_t>* multi_index) {
533     if (dimension == multi_index->size()) {
534       CompareValues(expected_.Get<NativeT>(*multi_index),
535                     actual_.Get<NativeT>(*multi_index),
536                     IndexUtil::MultidimensionalIndexToLinearIndex(
537                         actual_.shape(), *multi_index));
538     } else {
539       for (int64_t i = 0; i < expected_.shape().dimensions(dimension); ++i) {
540         (*multi_index)[dimension] = i;
541         CompareLiteralsSlow(dimension + 1, multi_index);
542       }
543     }
544   }
545 
546   // Returns an error message string with a detailed breakdown of the
547   // mismatches. Called after calling Run().
ErrorMessage()548   std::string ErrorMessage() {
549     std::string out;
550     int64_t element_count = ShapeUtil::ElementsIn(actual_.shape());
551 
552     auto percent_string = [](float a, float b) {
553       float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
554       return absl::StrFormat("%0.4f%%", pct);
555     };
556 
557     StrAppendFormat(
558         &out,
559         "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
560         "%g, rel bound %g\n",
561         num_mismatches_, percent_string(num_mismatches_, element_count),
562         ShapeUtil::HumanString(actual_.shape()),
563         ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
564     if (num_nan_mismatches_ > 0) {
565       StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
566     }
567     StrAppendFormat(&out, "Top relative error mismatches:\n");
568     for (auto it = top_rel_mismatches_.rbegin();
569          it != top_rel_mismatches_.rend(); ++it) {
570       StrAppend(&out, "  ", it->ToString(actual_.shape()), "\n");
571     }
572 
573     if (!detailed_message_) {
574       return out;
575     }
576 
577     StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
578     CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
579     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
580       const int64_t bucket_size = abs_value_buckets_[i].first;
581       const int64_t bucket_mismatches = abs_value_buckets_[i].second;
582       std::string mismatch_str =
583           bucket_mismatches > 0
584               ? absl::StrFormat(", mismatches %d", bucket_mismatches)
585               : "";
586       StrAppendFormat(&out, "  %-6g <= x < %-6g : %7d (%9s)%s\n",
587                       kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
588                       bucket_size, percent_string(bucket_size, element_count),
589                       mismatch_str);
590     }
591 
592     auto print_accum_buckets = [&](const std::string& header, int64_t total,
593                                    absl::Span<const int64_t> buckets) {
594       StrAppend(&out, header, ":\n");
595       StrAppendFormat(&out, "  <  %-6g : %7d (%s)\n", kErrorBucketBounds[0],
596                       total - buckets[0],
597                       percent_string(total - buckets[0], total));
598       CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
599       for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
600         StrAppendFormat(&out, "  >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
601                         buckets[i], percent_string(buckets[i], total));
602       }
603     };
604     StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
605                     error_.abs, num_abs_mismatches_,
606                     percent_string(num_abs_mismatches_, element_count));
607     print_accum_buckets(
608         "Relative error breakdown of elements exceeding abs error bound",
609         num_abs_mismatches_, rel_error_buckets_);
610     StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
611                     error_.rel, num_rel_mismatches_,
612                     percent_string(num_rel_mismatches_, element_count));
613     print_accum_buckets(
614         "Absolute error breakdown of elements exceeding rel error bound",
615         num_rel_mismatches_, abs_error_buckets_);
616     return out;
617   }
618 
619   // 'actual' and 'expected' literals being compared.
620   LiteralSlice expected_;
621   LiteralSlice actual_;
622 
623   // The shape index of the LiteralSlice that is being compared.
624   ShapeIndex shape_index_;
625 
626   // The error bounds of the comparison.
627   ErrorSpec error_;
628 
629   // Whether to include detailed breakdown of mismatches in the error message.
630   bool detailed_message_;
631 
632   // Callback to invoke on miscompare.
633   MiscompareCallback miscompare_callback_;
634 
635   // Number of element mismatches encountered so far.
636   int64_t num_mismatches_ = 0;
637 
638   // Number of elements with a nan mismatch.
639   int64_t num_nan_mismatches_ = 0;
640 
641   // Number of elements which exceed the absolute/relative error bound.
642   int64_t num_abs_mismatches_ = 0;
643   int64_t num_rel_mismatches_ = 0;
644 
645   // A Literal containing which elements did not match in the expected and
646   // actual literals. mismatches_ contains PREDs and is of the same sizes as
647   // the comparison literals.
648   Literal mismatches_;
649 
650   // The number of mismatches to report in the output, sorted by relative error
651   // magnitude.
652   static constexpr int64_t kTopRelativeErrorCount = 5;
653 
654   // The set of mismatches with the largest relative error. The size of this set
655   // is bounded by kTopRelativeErrorCount.
656   std::multiset<Mismatch> top_rel_mismatches_;
657 
658   // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
659   // bounds of these buckets. abs_value_buckets_ contains a pair for each
660   // bucket: the element count and failure count.
661   static constexpr std::array<float, 7> kAbsValueBucketBounds = {
662       0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
663   std::vector<std::pair<int64_t, int64_t>> abs_value_buckets_;
664 
665   // Buckets for relative and absolute errors. The relative error buckets only
666   // contain those elements which exceed the *absolute* error bound, and vice
667   // versa. This makes it easy to see the effect of adjusting the relative (or
668   // absolute) error bound on the success of the comparison. kErrorBucketBounds
669   // are the lower bounds of the buckets in both vectors. The error buckets are
670   // a cumulative distribution so an error value may appear in more than one
671   // bucket. For example an error value of 0.003 may appear in the buckets
672   // bounded by 0.01, 0.1, and 1.0.
673   static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
674                                                               0.01, 0.1, 1};
675   std::vector<int64_t> abs_error_buckets_;
676   std::vector<int64_t> rel_error_buckets_;
677 };
678 
// Out-of-class definitions of the static constexpr bucket-bound arrays that
// are declared (with their initializers) inside NearComparator. Before C++17,
// where static constexpr data members became implicitly inline, a definition
// like this is needed if the member is odr-used.
679 template <typename NativeT>
680 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
681 template <typename NativeT>
682 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
683 
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const MiscompareCallback & miscompare_callback)684 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
685                    const ShapeIndex& shape_index,
686                    const MiscompareCallback& miscompare_callback) {
687   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
688 
689   Status result;
690   if (expected.shape().IsTuple()) {
691     ShapeIndex next_index = shape_index;
692     for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
693       next_index.push_back(i);
694       Status tuple_result =
695           EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
696                       next_index, miscompare_callback);
697       if (miscompare_callback) {
698         result.Update(tuple_result);
699       } else {
700         TF_RETURN_IF_ERROR(tuple_result);
701       }
702       next_index.pop_back();
703     }
704   } else {
705     std::vector<int64_t> multi_index(expected.shape().dimensions_size(), 0);
706     auto index = absl::MakeSpan(multi_index);
707 
708     Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
709                                                expected.shape().dimensions());
710     Literal miscompared(unequal_shape);
711     Literal* miscompared_ptr =
712         (miscompare_callback == nullptr ? nullptr : &miscompared);
713 
714     switch (expected.shape().element_type()) {
715       case PRED:
716         result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
717         break;
718       case S8:
719         result = Equal<int8_t>(expected, actual, index, 0, miscompared_ptr);
720         break;
721       case S16:
722         result = Equal<int16_t>(expected, actual, index, 0, miscompared_ptr);
723         break;
724       case S32:
725         result = Equal<int32_t>(expected, actual, index, 0, miscompared_ptr);
726         break;
727       case S64:
728         result = Equal<int64_t>(expected, actual, index, 0, miscompared_ptr);
729         break;
730       case U8:
731         result = Equal<uint8_t>(expected, actual, index, 0, miscompared_ptr);
732         break;
733       case U16:
734         result = Equal<uint16_t>(expected, actual, index, 0, miscompared_ptr);
735         break;
736       case U32:
737         result = Equal<uint32_t>(expected, actual, index, 0, miscompared_ptr);
738         break;
739       case U64:
740         result = Equal<uint64_t>(expected, actual, index, 0, miscompared_ptr);
741         break;
742       case BF16:
743         result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
744         break;
745       case F16:
746         result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
747         break;
748       case F32:
749         result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
750         break;
751       case F64:
752         result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
753         break;
754       case C64:
755         result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
756         break;
757       case C128:
758         result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
759         break;
760       case TOKEN:
761         // Tokens have no on-device representation and are trivially equal.
762         return OkStatus();
763       default:
764         LOG(FATAL) << "Unsupported primitive type: "
765                    << PrimitiveType_Name(expected.shape().element_type());
766     }
767 
768     if (!result.ok() && miscompare_callback) {
769       miscompare_callback(expected, actual, LiteralSlice(miscompared),
770                           shape_index, ErrorBuckets());
771     }
772   }
773 
774   return result;
775 }
776 
777 // Helper function for comparing two literals for nearness. Handles tuple-shapes
778 // via recursion. shape_index is the ShapeIndex of expected (or actual)
779 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const ErrorSpec & error,std::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)780 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
781                   const ShapeIndex& shape_index, const ErrorSpec& error,
782                   std::optional<bool> detailed_message,
783                   const MiscompareCallback& miscompare_callback) {
784   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
785 
786   if (expected.shape().IsTuple()) {
787     Status return_status;
788     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(expected.shape());
789          ++i) {
790       const auto expected_element = LiteralSlice(expected, {i});
791       const auto actual_element = LiteralSlice(actual, {i});
792       ShapeIndex element_index = shape_index;
793       element_index.push_back(i);
794       Status element_result =
795           NearHelper(expected_element, actual_element, element_index, error,
796                      detailed_message, miscompare_callback);
797       if (!element_result.ok()) {
798         element_result = InvalidArgument("Array at shape index %s, %s",
799                                          element_index.ToString(),
800                                          element_result.error_message());
801         if (return_status.ok()) {
802           return_status = element_result;
803         } else {
804           return_status =
805               AppendStatus(return_status, element_result.error_message());
806         }
807       }
808     }
809     if (!return_status.ok() && shape_index.empty()) {
810       // Emit a top-level error message containing the top-level shape in case
811       // of mismatch.
812       int64_t total_elements = RecursiveElementCount(actual.shape());
813       return_status =
814           InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
815                           ShapeUtil::HumanString(actual.shape()),
816                           total_elements, return_status.error_message());
817     }
818     return return_status;
819   }
820 
821   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
822       ShapeUtil::ElementIsComplex(expected.shape())) {
823     bool use_detailed_message = detailed_message.value_or(
824         ShapeUtil::ElementsIn(expected.shape()) >= 64);
825     switch (expected.shape().element_type()) {
826       case BF16:
827         return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
828                                                  error, use_detailed_message,
829                                                  miscompare_callback);
830         break;
831       case F16:
832         return NearComparator<half>::Compare(expected, actual, shape_index,
833                                              error, use_detailed_message,
834                                              miscompare_callback);
835         break;
836       case F32:
837         return NearComparator<float>::Compare(expected, actual, shape_index,
838                                               error, use_detailed_message,
839                                               miscompare_callback);
840         break;
841       case F64:
842         return NearComparator<double>::Compare(expected, actual, shape_index,
843                                                error, use_detailed_message,
844                                                miscompare_callback);
845         break;
846       case C64:
847         return NearComparator<complex64>::Compare(expected, actual, shape_index,
848                                                   error, use_detailed_message,
849                                                   miscompare_callback);
850         break;
851       case C128:
852         return NearComparator<complex128>::Compare(
853             expected, actual, shape_index, error, use_detailed_message,
854             miscompare_callback);
855         break;
856       default:
857         LOG(FATAL) << "Unsupported primitive type in near comparator: "
858                    << PrimitiveType_Name(expected.shape().element_type())
859                    << ". Must be floating-point type.";
860     }
861   }
862 
863   // Non-floating point, non-tuple literal.
864   return EqualHelper(expected, actual, shape_index, miscompare_callback);
865 }
866 
867 }  // namespace
868 
EqualShapes(const Shape & expected,const Shape & actual)869 Status EqualShapes(const Shape& expected, const Shape& actual) {
870   if (expected.element_type() != actual.element_type()) {
871     return InvalidArgument("element type mismatch, want: %s got %s",
872                            ShapeUtil::HumanString(expected),
873                            ShapeUtil::HumanString(actual));
874   }
875   if (expected.IsTuple()) {
876     if (ShapeUtil::TupleElementCount(expected) !=
877         ShapeUtil::TupleElementCount(actual)) {
878       return InvalidArgument(
879           "want tuple element count: %d got tuple element count: %d",
880           ShapeUtil::TupleElementCount(expected),
881           ShapeUtil::TupleElementCount(actual));
882     }
883     for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
884       Status result =
885           EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
886       if (!result.ok()) {
887         return AppendStatus(result, StrCat("mismatch in tuple index", i));
888       }
889     }
890   } else if (expected.IsArray()) {
891     if (expected.rank() != actual.rank()) {
892       return InvalidArgument("want rank of %s got rank of %s",
893                              ShapeUtil::HumanString(expected),
894                              ShapeUtil::HumanString(actual));
895     }
896     if (expected.element_type() != actual.element_type()) {
897       return InvalidArgument("mismatch in primitive type %s vs %s",
898                              PrimitiveType_Name(expected.element_type()),
899                              PrimitiveType_Name(actual.element_type()));
900     }
901     if (expected.dimensions_size() != actual.dimensions_size()) {
902       return InvalidArgument("want dimensions_size %d got dimensions_size %d",
903                              expected.dimensions_size(),
904                              actual.dimensions_size());
905     }
906     for (int i = 0; i < expected.dimensions_size(); ++i) {
907       if (expected.dimensions(i) != actual.dimensions(i)) {
908         return InvalidArgument(
909             "mismatch in dimension #%d expected: %s actual: %s", i,
910             ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
911       }
912     }
913   }
914   // Non-array, non-tuple shapes are trivially equivalent.
915   return OkStatus();
916 }
917 
918 namespace {
919 
920 // If result is an error, extend the error message with the expected and actual
921 // literals.
EmitLiteralsInErrorMessage(const Status & result,const LiteralSlice & expected,const LiteralSlice & actual)922 Status EmitLiteralsInErrorMessage(const Status& result,
923                                   const LiteralSlice& expected,
924                                   const LiteralSlice& actual) {
925   if (result.ok()) {
926     return result;
927   }
928   return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
929                          result.error_message(), ToStringTruncated(expected),
930                          ToStringTruncated(actual));
931 }
932 
933 }  // namespace
934 
Equal(const LiteralSlice & expected,const LiteralSlice & actual)935 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
936   VLOG(1) << "expected:";
937   XLA_VLOG_LINES(1, expected.ToString());
938   VLOG(1) << "actual:";
939   XLA_VLOG_LINES(1, actual.ToString());
940   Status result = EqualHelper(expected, actual, {}, nullptr);
941   return EmitLiteralsInErrorMessage(result, expected, actual);
942 }
943 
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,std::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)944 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
945             const ErrorSpec& error, std::optional<bool> detailed_message,
946             const MiscompareCallback& miscompare_callback) {
947   VLOG(1) << "Expected literal:";
948   XLA_VLOG_LINES(1, expected.ToString());
949   VLOG(1) << "Actual literal:";
950   XLA_VLOG_LINES(1, actual.ToString());
951   Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
952                              detailed_message, miscompare_callback);
953   return EmitLiteralsInErrorMessage(result, expected, actual);
954 }
955 
ToStringTruncated(const LiteralSlice & literal)956 std::string ToStringTruncated(const LiteralSlice& literal) {
957   return RecursiveElementCount(literal.shape()) < 1000
958              ? literal.ToString()
959              : "[TRUNCATED, Literal with more than 1000 values]";
960 }
961 
962 }  // namespace literal_comparison
963 }  // namespace xla
964