1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/comparison_util.h"
17
18 #include <optional>
19 #include <string>
20
21 #include "absl/base/attributes.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace xla {
29 namespace {
30
31 // Verifies that this is a valid Comparison: (1) not a partial ordering on
32 // integers, and (2) a valid PrimitiveType.
IsValidComparison(xla::PrimitiveType type,Comparison::Order order)33 bool IsValidComparison(xla::PrimitiveType type, Comparison::Order order) {
34 switch (type) {
35 case F16:
36 case F32:
37 case BF16:
38 case F64:
39 case C64:
40 case C128:
41 return true;
42 case S8:
43 case S16:
44 case S32:
45 case S64:
46 case PRED:
47 case U8:
48 case U16:
49 case U32:
50 case U64:
51 return order == Comparison::Order::kTotal;
52 case TUPLE:
53 case OPAQUE_TYPE:
54 case TOKEN:
55 case PRIMITIVE_TYPE_INVALID:
56 case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
57 case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
58 return false;
59 }
60 }
61
62 // Returns the X32 primitive type for each Type.
DefaultPrimitiveType(Comparison::Type type)63 PrimitiveType DefaultPrimitiveType(Comparison::Type type) {
64 switch (type) {
65 case Comparison::Type::kFloat:
66 case Comparison::Type::kFloatTotalOrder:
67 return PrimitiveType::F32;
68 case Comparison::Type::kSigned:
69 return PrimitiveType::S32;
70 case Comparison::Type::kUnsigned:
71 return PrimitiveType::U32;
72 }
73 }
74
75 // Returns the default ordering for each Comparison::Type.
DefaultOrdering(Comparison::Type type)76 Comparison::Order DefaultOrdering(Comparison::Type type) {
77 switch (type) {
78 case Comparison::Type::kFloat:
79 return Comparison::Order::kPartial;
80 case Comparison::Type::kFloatTotalOrder:
81 case Comparison::Type::kSigned:
82 case Comparison::Type::kUnsigned:
83 return Comparison::Order::kTotal;
84 }
85 }
86
87 // Returns the expected ordering for each primitive type.
DefaultOrdering(PrimitiveType type)88 Comparison::Order DefaultOrdering(PrimitiveType type) {
89 switch (type) {
90 case S8:
91 case S16:
92 case S32:
93 case S64:
94 case PRED:
95 case U8:
96 case U16:
97 case U32:
98 case U64:
99 return Comparison::Order::kTotal;
100 case BF16:
101 case F16:
102 case F32:
103 case F64:
104 case C64:
105 case C128:
106 return Comparison::Order::kPartial;
107 default:
108 LOG(FATAL) << "Unsupported type: " << PrimitiveType_Name(type);
109 }
110 }
111
112 // Returns the converse of `direction`.
Converse(Comparison::Direction direction)113 Comparison::Direction Converse(Comparison::Direction direction) {
114 switch (direction) {
115 case Comparison::Direction::kEq:
116 return Comparison::Direction::kEq;
117 case Comparison::Direction::kNe:
118 return Comparison::Direction::kNe;
119 case Comparison::Direction::kGe:
120 return Comparison::Direction::kLe;
121 case Comparison::Direction::kGt:
122 return Comparison::Direction::kLt;
123 case Comparison::Direction::kLe:
124 return Comparison::Direction::kGe;
125 case Comparison::Direction::kLt:
126 return Comparison::Direction::kGt;
127 }
128 }
129
130 // Returns the inverse of `direction`.
Inverse(Comparison::Direction direction)131 Comparison::Direction Inverse(Comparison::Direction direction) {
132 switch (direction) {
133 case Comparison::Direction::kEq:
134 return Comparison::Direction::kNe;
135 case Comparison::Direction::kNe:
136 return Comparison::Direction::kEq;
137 case Comparison::Direction::kGe:
138 return Comparison::Direction::kLt;
139 case Comparison::Direction::kGt:
140 return Comparison::Direction::kLe;
141 case Comparison::Direction::kLe:
142 return Comparison::Direction::kGt;
143 case Comparison::Direction::kLt:
144 return Comparison::Direction::kGe;
145 }
146 }
147
148 } // namespace
149
ComparisonDirectionToString(Comparison::Direction direction)150 std::string ComparisonDirectionToString(Comparison::Direction direction) {
151 switch (direction) {
152 case Comparison::Direction::kEq:
153 return "EQ";
154 case Comparison::Direction::kNe:
155 return "NE";
156 case Comparison::Direction::kGe:
157 return "GE";
158 case Comparison::Direction::kGt:
159 return "GT";
160 case Comparison::Direction::kLe:
161 return "LE";
162 case Comparison::Direction::kLt:
163 return "LT";
164 default:
165 LOG(FATAL) << "Attempted to print uninitialized comparison direction";
166 }
167 }
168
ComparisonTypeToString(Comparison::Type type)169 std::string ComparisonTypeToString(Comparison::Type type) {
170 switch (type) {
171 case Comparison::Type::kFloat:
172 return "FLOAT";
173 case Comparison::Type::kFloatTotalOrder:
174 return "TOTALORDER";
175 case Comparison::Type::kSigned:
176 return "SIGNED";
177 case Comparison::Type::kUnsigned:
178 return "UNSIGNED";
179 }
180 }
181
ComparisonPrimitiveTypeToString(PrimitiveType type)182 std::string ComparisonPrimitiveTypeToString(PrimitiveType type) {
183 return PrimitiveType_Name(type);
184 }
185
ComparisonOrderToString(Comparison::Order order)186 std::string ComparisonOrderToString(Comparison::Order order) {
187 switch (order) {
188 case Comparison::Order::kPartial:
189 return "PARTIALORDER";
190 case Comparison::Order::kTotal:
191 return "TOTALORDER";
192 }
193 }
194
StringToComparisonDirection(absl::string_view direction)195 StatusOr<Comparison::Direction> StringToComparisonDirection(
196 absl::string_view direction) {
197 static auto* map =
198 new absl::flat_hash_map<std::string, Comparison::Direction>({
199 {"EQ", Comparison::Direction::kEq},
200 {"NE", Comparison::Direction::kNe},
201 {"GE", Comparison::Direction::kGe},
202 {"GT", Comparison::Direction::kGt},
203 {"LE", Comparison::Direction::kLe},
204 {"LT", Comparison::Direction::kLt},
205 });
206 auto it = map->find(direction);
207 if (it == map->end()) {
208 return InvalidArgument("Unknown comparison direction: %s", direction);
209 }
210 return it->second;
211 }
212
StringToComparisonOrder(absl::string_view order)213 StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order) {
214 static auto* map = new absl::flat_hash_map<std::string, Comparison::Order>({
215 {"TOTALORDER", Comparison::Order::kTotal},
216 {"PARTIALORDER", Comparison::Order::kPartial},
217 });
218 auto it = map->find(order);
219 if (it == map->end()) {
220 return InvalidArgument("Unknown comparison type: %s", order);
221 }
222 return it->second;
223 }
224
StringToComparisonType(absl::string_view comparison)225 StatusOr<Comparison::Type> StringToComparisonType(
226 absl::string_view comparison) {
227 static auto* map = new absl::flat_hash_map<std::string, Comparison::Type>({
228 {"FLOAT", Comparison::Type::kFloat},
229 {"TOTALORDER", Comparison::Type::kFloatTotalOrder},
230 {"SIGNED", Comparison::Type::kSigned},
231 {"UNSIGNED", Comparison::Type::kUnsigned},
232 });
233 auto it = map->find(comparison);
234 if (it == map->end()) {
235 return InvalidArgument("Unknown comparison type: %s", comparison);
236 }
237 return it->second;
238 }
239
DefaultComparisonType(PrimitiveType type)240 Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
241 switch (type) {
242 case S8:
243 case S16:
244 case S32:
245 case S64:
246 return Type::kSigned;
247 case PRED:
248 case U8:
249 case U16:
250 case U32:
251 case U64:
252 return Type::kUnsigned;
253 case F16:
254 case F32:
255 case BF16:
256 case F64:
257 case C64:
258 case C128:
259 return Type::kFloat;
260 default:
261 LOG(FATAL) << "Unexpected: " << PrimitiveType_Name(type);
262 }
263 }
264
Comparison(Direction dir,PrimitiveType type,Order order)265 Comparison::Comparison(Direction dir, PrimitiveType type, Order order)
266 : dir_(dir),
267 primitive_type_(type),
268 order_(order),
269 type_(DefaultComparisonType(type)) {
270 CHECK(IsValidComparison(primitive_type_, order_));
271 }
272
Comparison(Direction dir,PrimitiveType type)273 Comparison::Comparison(Direction dir, PrimitiveType type)
274 : dir_(dir),
275 primitive_type_(type),
276 order_(DefaultOrdering(type)),
277 type_(DefaultComparisonType(type)) {
278 CHECK(IsValidComparison(primitive_type_, order_));
279 }
280
Comparison(Direction dir,Type type)281 Comparison::Comparison(Direction dir, Type type)
282 : dir_(dir),
283 primitive_type_(DefaultPrimitiveType(type)),
284 order_(DefaultOrdering(type)),
285 type_(type) {
286 CHECK(IsValidComparison(primitive_type_, order_));
287 }
288
Converse() const289 Comparison Comparison::Converse() const {
290 return Comparison(xla::Converse(dir_), primitive_type_, order_);
291 }
292
Inverse() const293 std::optional<Comparison> Comparison::Inverse() const {
294 if (IsPartialOrder()) {
295 // We assume comparisons don't have inverses unless they are total order,
296 // e.g., a partial order floating point comparison can return true if one
297 // operand is NaN.
298 return std::nullopt;
299 }
300 switch (primitive_type_) {
301 case F16:
302 case F32:
303 case BF16:
304 case F64:
305 case C64:
306 case C128:
307 case S8:
308 case S16:
309 case S32:
310 case S64:
311 case PRED:
312 case U8:
313 case U16:
314 case U32:
315 case U64:
316 return Comparison(xla::Inverse(dir_), primitive_type_, order_);
317 case TUPLE:
318 case OPAQUE_TYPE:
319 case TOKEN:
320 case PRIMITIVE_TYPE_INVALID:
321 case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
322 case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
323 return std::nullopt;
324 }
325 }
326
IsReflexive() const327 bool Comparison::IsReflexive() const {
328 switch (dir_) {
329 case Direction::kEq:
330 case Direction::kGe:
331 case Direction::kLe:
332 return IsTotalOrder();
333 case Direction::kNe:
334 case Direction::kGt:
335 case Direction::kLt:
336 return false;
337 }
338 }
339
IsAntireflexive() const340 bool Comparison::IsAntireflexive() const {
341 switch (dir_) {
342 case Direction::kNe:
343 return IsTotalOrder();
344 case Direction::kGt:
345 case Direction::kLt:
346 return true;
347 case Direction::kEq:
348 case Direction::kGe:
349 case Direction::kLe:
350 return false;
351 }
352 }
353
ToString(std::string prefix1,std::string prefix2,std::string prefix3) const354 std::string Comparison::ToString(std::string prefix1, std::string prefix2,
355 std::string prefix3) const {
356 return absl::StrCat(prefix1, ComparisonDirectionToString(dir_), prefix2,
357 ComparisonPrimitiveTypeToString(primitive_type_), prefix3,
358 ComparisonOrderToString(order_));
359 }
360 } // namespace xla
361