1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H
17 #define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
24 #include "mlir/Dialect/Complex/IR/Complex.h"
25 #include "mlir/Dialect/Math/IR/Math.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/ImplicitLocOpBuilder.h"
30 #include "mlir/IR/TypeUtilities.h"
31 
32 namespace mlir {
33 namespace mhlo {
34 namespace impl {
35 
36 // A struct to map MhloBinaryOpTy type to the corresponding floating-point and
37 // integer scalar operation types.
38 template <typename MhloBinaryOpTy>
39 struct MhloToScalarOp {
40   using FOp = void;
41   using IOp = void;
42   using UOp = void;
43   using COp = void;
44 };
45 
46 template <>
47 struct MhloToScalarOp<mhlo::AddOp> {
48   using FOp = ::mlir::arith::AddFOp;
49   using IOp = ::mlir::arith::AddIOp;
50   using UOp = ::mlir::arith::AddIOp;
51   using COp = ::mlir::complex::AddOp;
52 };
53 template <>
54 struct MhloToScalarOp<mhlo::AndOp> {
55   using IOp = ::mlir::arith::AndIOp;
56   using UOp = ::mlir::arith::AndIOp;
57 };
58 template <>
59 struct MhloToScalarOp<mhlo::CompareOp> {
60   using FOp = ::mlir::arith::CmpFOp;
61   using IOp = ::mlir::arith::CmpIOp;
62   using UOp = ::mlir::arith::CmpIOp;
63 };
64 template <>
65 struct MhloToScalarOp<mhlo::CeilOp> {
66   using FOp = ::mlir::math::CeilOp;
67 };
68 template <>
69 struct MhloToScalarOp<mhlo::ClzOp> {
70   using IOp = ::mlir::math::CountLeadingZerosOp;
71   using UOp = ::mlir::math::CountLeadingZerosOp;
72 };
73 template <>
74 struct MhloToScalarOp<mhlo::CosineOp> {
75   using FOp = ::mlir::math::CosOp;
76   using COp = ::mlir::complex::CosOp;
77 };
78 template <>
79 struct MhloToScalarOp<mhlo::ExpOp> {
80   using FOp = ::mlir::math::ExpOp;
81   using COp = ::mlir::complex::ExpOp;
82 };
83 template <>
84 struct MhloToScalarOp<mhlo::Expm1Op> {
85   using FOp = ::mlir::math::ExpM1Op;
86   using COp = ::mlir::complex::Expm1Op;
87 };
88 template <>
89 struct MhloToScalarOp<mhlo::FloorOp> {
90   using FOp = ::mlir::math::FloorOp;
91 };
92 template <>
93 struct MhloToScalarOp<mhlo::MaxOp> {
94   using FOp = ::mlir::arith::MaxFOp;
95   using IOp = ::mlir::arith::MaxSIOp;
96   using UOp = ::mlir::arith::MaxUIOp;
97 };
98 template <>
99 struct MhloToScalarOp<mhlo::MinOp> {
100   using FOp = ::mlir::arith::MinFOp;
101   using IOp = ::mlir::arith::MinSIOp;
102   using UOp = ::mlir::arith::MinUIOp;
103 };
104 template <>
105 struct MhloToScalarOp<mhlo::LogOp> {
106   using FOp = ::mlir::math::LogOp;
107   using COp = ::mlir::complex::LogOp;
108 };
109 template <>
110 struct MhloToScalarOp<mhlo::Log1pOp> {
111   using FOp = ::mlir::math::Log1pOp;
112   using COp = ::mlir::complex::Log1pOp;
113 };
114 template <>
115 struct MhloToScalarOp<mhlo::MulOp> {
116   using FOp = ::mlir::arith::MulFOp;
117   using IOp = ::mlir::arith::MulIOp;
118   using UOp = ::mlir::arith::MulIOp;
119   using COp = ::mlir::complex::MulOp;
120 };
121 template <>
122 struct MhloToScalarOp<mhlo::OrOp> {
123   using IOp = ::mlir::arith::OrIOp;
124   using UOp = ::mlir::arith::OrIOp;
125 };
126 template <>
127 struct MhloToScalarOp<mhlo::PopulationCountOp> {
128   using IOp = ::mlir::math::CtPopOp;
129   using UOp = ::mlir::math::CtPopOp;
130 };
131 template <>
132 struct MhloToScalarOp<mhlo::RsqrtOp> {
133   using FOp = ::mlir::math::RsqrtOp;
134   using COp = ::mlir::complex::RsqrtOp;
135 };
136 template <>
137 struct MhloToScalarOp<mhlo::RoundOp> {
138   using FOp = ::mlir::math::RoundOp;
139 };
140 template <>
141 struct MhloToScalarOp<mhlo::SubtractOp> {
142   using FOp = ::mlir::arith::SubFOp;
143   using IOp = ::mlir::arith::SubIOp;
144   using UOp = ::mlir::arith::SubIOp;
145   using COp = ::mlir::complex::SubOp;
146 };
147 template <>
148 struct MhloToScalarOp<mhlo::SqrtOp> {
149   using FOp = ::mlir::math::SqrtOp;
150   using COp = ::mlir::complex::SqrtOp;
151 };
152 template <>
153 struct MhloToScalarOp<mhlo::SineOp> {
154   using FOp = ::mlir::math::SinOp;
155   using COp = ::mlir::complex::SinOp;
156 };
157 template <>
158 struct MhloToScalarOp<mhlo::ShiftLeftOp> {
159   using IOp = ::mlir::arith::ShLIOp;
160   using UOp = ::mlir::arith::ShLIOp;
161 };
162 template <>
163 struct MhloToScalarOp<mhlo::ShiftRightArithmeticOp> {
164   using IOp = ::mlir::arith::ShRSIOp;
165   using UOp = ::mlir::arith::ShRSIOp;
166 };
167 template <>
168 struct MhloToScalarOp<mhlo::ShiftRightLogicalOp> {
169   using IOp = ::mlir::arith::ShRUIOp;
170   using UOp = ::mlir::arith::ShRUIOp;
171 };
172 template <>
173 struct MhloToScalarOp<mhlo::Atan2Op> {
174   using FOp = ::mlir::math::Atan2Op;
175   using COp = ::mlir::complex::Atan2Op;
176 };
177 template <>
178 struct MhloToScalarOp<mhlo::TanhOp> {
179   using FOp = ::mlir::math::TanhOp;
180   using COp = ::mlir::complex::TanhOp;
181 };
182 template <>
183 struct MhloToScalarOp<mhlo::XorOp> {
184   using IOp = ::mlir::arith::XOrIOp;
185   using UOp = ::mlir::arith::XOrIOp;
186 };
187 
// Alias for the scalar op that MhloToScalarOp maps MhloOp to for
// floating-point elements. Names no type if the specialization for MhloOp
// does not declare FOp.
template <typename MhloOp>
using ScalarFOp = typename MhloToScalarOp<MhloOp>::FOp;
// Alias for the scalar op that MhloToScalarOp maps MhloOp to for signed
// (and signless) integer elements. Names no type if IOp is not declared.
template <typename MhloOp>
using ScalarIOp = typename MhloToScalarOp<MhloOp>::IOp;
// Alias for the scalar op that MhloToScalarOp maps MhloOp to for unsigned
// integer elements. Names no type if UOp is not declared.
template <typename MhloOp>
using ScalarUOp = typename MhloToScalarOp<MhloOp>::UOp;
// Alias for the scalar op that MhloToScalarOp maps MhloOp to for complex
// elements. Names no type if COp is not declared.
template <typename MhloOp>
using ScalarCOp = typename MhloToScalarOp<MhloOp>::COp;
200 
// Compile-time dispatcher over a list of (predicate, scalar op) pairs.
// The element type of argTypes.front() is tested against each predicate in
// list order. The op paired with the first predicate that matches is created
// from `resultTypes` and `args`.
//
// Primary template: reached when the list is exhausted (no pair matched).
// No lowering applies, so a null Value is returned.
template <typename... Args>
struct MapMhloOpToScalarOpImpl {
  Value operator()(Location /*loc*/, ArrayRef<Type> /*ResultTypes*/,
                   ArrayRef<Type> /*argTypes*/, ValueRange /*args*/,
                   OpBuilder* /*b*/) {
    return nullptr;
  }
};

// A single op with no predicate: create StdScalarOp unconditionally.
template <typename StdScalarOp>
struct MapMhloOpToScalarOpImpl<StdScalarOp> {
  Value operator()(Location loc, ArrayRef<Type> resultTypes,
                   ArrayRef<Type> /*argTypes*/, ValueRange args, OpBuilder* b) {
    return b->template create<StdScalarOp>(loc, resultTypes, args, mlir::None);
  }
};

// A (predicate, op) pair at the head of the list. If the first argument's
// element type satisfies SupportedType, StdScalarOp is created. Otherwise
// dispatch continues with the remaining pairs.
template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapMhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> {
  Value operator()(Location loc, ArrayRef<Type> resultTypes,
                   ArrayRef<Type> argTypes, ValueRange args, OpBuilder* b) {
    Type elementType = getElementTypeOrSelf(argTypes.front());
    if (SupportedType{}(elementType)) {
      return b->template create<StdScalarOp>(loc, resultTypes, args,
                                             mlir::None);
    }
    return MapMhloOpToScalarOpImpl<Args...>{}(loc, resultTypes, argTypes, args,
                                              b);
  }
};

// A (predicate, void) pair means this element-type family has no mapping.
// The pair is skipped without evaluating the predicate. This specialization
// is more specialized than the generic pair above, so it is selected whenever
// the op slot is `void`.
template <typename SupportedType, typename... Args>
struct MapMhloOpToScalarOpImpl<SupportedType, void, Args...> {
  Value operator()(Location loc, ArrayRef<Type> resultTypes,
                   ArrayRef<Type> argTypes, ValueRange args, OpBuilder* b) {
    return MapMhloOpToScalarOpImpl<Args...>{}(loc, resultTypes, argTypes, args,
                                              b);
  }
};
240 
241 struct IsAnyIntegerType {
242   bool operator()(Type t) { return t.isa<IntegerType>(); }
243 };
244 
245 struct IsSignedIntegerType {
246   bool operator()(Type t) {
247     // Pretend that signless is signed. This will change eventually.
248     return t.isa<IntegerType>() && !t.isUnsignedInteger() &&
249            !t.isSignlessInteger(1);
250   }
251 };
252 
253 struct IsUnsignedIntegerType {
254   bool operator()(Type t) {
255     return t.isUnsignedInteger() || t.isSignlessInteger(1);
256   }
257 };
258 
259 struct IsFloatType {
260   bool operator()(Type t) { return t.isa<FloatType>(); }
261 };
262 
263 struct IsComplexType {
264   bool operator()(Type t) { return t.isa<ComplexType>(); }
265 };
266 
// Yields MapTy<OpTy> when that alias template names a type (detected with
// llvm::is_detected), and `void` otherwise. This lets the generic mapping
// below skip element-type families that an MhloToScalarOp specialization
// does not declare, instead of failing to compile. MapTy<OpTy> is only named
// in the `true` specialization, so it is never instantiated when invalid.
template <template <typename T> class MapTy, typename OpTy,
          typename PredTy = llvm::is_detected<MapTy, OpTy>>
struct MapableIf {
  using type = void;
};
template <template <typename T> class MapTy, typename OpTy>
struct MapableIf<MapTy, OpTy, std::true_type> {
  using type = MapTy<OpTy>;
};
276 
277 // Inserts the computation that corresponds to the body of the loop for lowered
278 // MHLO unary/binary op. Returns the value for the result.
279 template <typename MhloOpTy>
280 inline Value mapMhloOpToStdScalarOp(Location loc, ArrayRef<Type> resultTypes,
281                                     ArrayRef<Type> argTypes, ValueRange args,
282                                     OpBuilder* b) {
283   using ScalarIOpOrVoid = typename MapableIf<ScalarIOp, MhloOpTy>::type;
284   using ScalarUOpOrVoid = typename MapableIf<ScalarUOp, MhloOpTy>::type;
285   using ScalarFOpOrVoid = typename MapableIf<ScalarFOp, MhloOpTy>::type;
286   using ScalarCOpOrVoid = typename MapableIf<ScalarCOp, MhloOpTy>::type;
287   return MapMhloOpToScalarOpImpl<IsSignedIntegerType, ScalarIOpOrVoid,
288                                  IsUnsignedIntegerType, ScalarUOpOrVoid,
289                                  IsFloatType, ScalarFOpOrVoid, IsComplexType,
290                                  ScalarCOpOrVoid>{}(loc, resultTypes, argTypes,
291                                                     args, b);
292 }
293 
294 template <>
295 inline Value mapMhloOpToStdScalarOp<mhlo::AbsOp>(Location loc,
296                                                  ArrayRef<Type> resultTypes,
297                                                  ArrayRef<Type> argTypes,
298                                                  ValueRange args,
299                                                  OpBuilder* b) {
300   Type elementType = getElementTypeOrSelf(argTypes.front());
301   if (elementType.isa<FloatType>()) {
302     return MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::math::AbsFOp>{}(
303         loc, resultTypes, argTypes, args, b);
304   }
305   if (elementType.isa<ComplexType>()) {
306     return MapMhloOpToScalarOpImpl<IsComplexType, ::mlir::complex::AbsOp>{}(
307         loc, resultTypes, argTypes, args, b);
308   }
309   if (elementType.isSignlessInteger() || elementType.isSignedInteger()) {
310     // lmhlo.abs(x, result) ->  result = select((x > 0), x, sub(0, x))
311     Value lhs = args[0];
312     Value zeroIntval =
313         b->create<arith::ConstantOp>(loc, b->getZeroAttr(lhs.getType()));
314     auto lhsGtZero = b->create<ScalarIOp<CompareOp>>(
315         loc, arith::CmpIPredicate::sge, lhs, zeroIntval);
316     auto negVal = b->create<ScalarIOp<mhlo::SubtractOp>>(loc, zeroIntval, lhs);
317     return b->create<::mlir::arith::SelectOp>(loc, lhsGtZero, lhs, negVal);
318   }
319   return nullptr;
320 }
321 
322 // Return a constant for v of type t, splat if t is a vector type.
323 inline Value getConstantOrSplat(OpBuilder* b, Location loc, Type t,
324                                 Attribute v) {
325   if (VectorType vecType = t.dyn_cast<VectorType>()) {
326     v = SplatElementsAttr::get(vecType, v);
327   }
328   return b->create<arith::ConstantOp>(loc, t, v);
329 }
330 
331 template <>
332 inline Value mapMhloOpToStdScalarOp<mhlo::CbrtOp>(Location loc,
333                                                   ArrayRef<Type> resultTypes,
334                                                   ArrayRef<Type> argTypes,
335                                                   ValueRange args,
336                                                   OpBuilder* b) {
337   mhlo::CbrtOp::Adaptor adaptor(args);
338   Type elementType = getElementTypeOrSelf(argTypes.front());
339   if (auto floatType = elementType.dyn_cast<FloatType>()) {
340     // Convert cbrt(x) to copysign(cbrt(abs(x), 1.0 / 3.0), x).
341     // This is to allow cbrt using pow while still handling negative numbers. It
342     // should match most cbrt intrinsics.
343     Value abs = b->create<mlir::math::AbsFOp>(loc, adaptor.operand());
344     Value third = b->create<arith::ConstantOp>(
345         loc, b->getFloatAttr(floatType, 1.0 / 3.0));
346     Value pow = b->create<mlir::math::PowFOp>(loc, resultTypes[0], abs, third);
347     return b->create<mlir::math::CopySignOp>(loc, floatType, pow,
348                                              adaptor.operand());
349   }
350   return nullptr;
351 }
352 
353 template <typename PredicateType>
354 inline Optional<PredicateType> getCmpPredicate(mhlo::ComparisonDirection,
355                                                bool) {
356   return llvm::None;
357 }
358 
359 template <>
360 inline Optional<arith::CmpFPredicate> getCmpPredicate<arith::CmpFPredicate>(
361     mhlo::ComparisonDirection comparisonDirection, bool isSigned) {
362   assert(isSigned && "cannot have an unsigned float!");
363   return llvm::StringSwitch<Optional<arith::CmpFPredicate>>(
364              stringifyComparisonDirection(comparisonDirection))
365       .Case("EQ", arith::CmpFPredicate::OEQ)
366       .Case("NE", arith::CmpFPredicate::UNE)
367       .Case("GE", arith::CmpFPredicate::OGE)
368       .Case("GT", arith::CmpFPredicate::OGT)
369       .Case("LE", arith::CmpFPredicate::OLE)
370       .Case("LT", arith::CmpFPredicate::OLT)
371       .Default(llvm::None);
372 }
373 
374 template <>
375 inline Optional<arith::CmpIPredicate> getCmpPredicate<arith::CmpIPredicate>(
376     mhlo::ComparisonDirection comparisonDirection, bool isSigned) {
377   return llvm::StringSwitch<Optional<arith::CmpIPredicate>>(
378              stringifyComparisonDirection(comparisonDirection))
379       .Case("EQ", arith::CmpIPredicate::eq)
380       .Case("NE", arith::CmpIPredicate::ne)
381       .Case("GE",
382             isSigned ? arith::CmpIPredicate::sge : arith::CmpIPredicate::uge)
383       .Case("GT",
384             isSigned ? arith::CmpIPredicate::sgt : arith::CmpIPredicate::ugt)
385       .Case("LE",
386             isSigned ? arith::CmpIPredicate::sle : arith::CmpIPredicate::ule)
387       .Case("LT",
388             isSigned ? arith::CmpIPredicate::slt : arith::CmpIPredicate::ult)
389       .Default(llvm::None);
390 }
391 
392 inline Value mapCompareOpToStdScalarOp(Location loc,
393                                        ComparisonDirection comparisonDirection,
394                                        ArrayRef<Type> /*ResultTypes*/,
395                                        ArrayRef<Type> argTypes, ValueRange args,
396                                        OpBuilder* b) {
397   const auto& lhs = args[0];
398   const auto& rhs = args[1];
399   Type elementType = getElementTypeOrSelf(argTypes.front());
400   if (elementType.isa<IntegerType>()) {
401     bool isUnsigned = IsUnsignedIntegerType{}(elementType);
402     Optional<arith::CmpIPredicate> predicate =
403         getCmpPredicate<arith::CmpIPredicate>(comparisonDirection, !isUnsigned);
404     assert(predicate.has_value() && "expected valid comparison direction");
405     return b->create<ScalarIOp<mhlo::CompareOp>>(loc, predicate.value(), lhs,
406                                                  rhs);
407   }
408   if (elementType.isa<FloatType>()) {
409     Optional<arith::CmpFPredicate> predicate =
410         getCmpPredicate<arith::CmpFPredicate>(comparisonDirection,
411                                               /*is_signed=*/true);
412     assert(predicate.has_value() && "expected valid comparison direction");
413     return b->create<ScalarFOp<mhlo::CompareOp>>(loc, predicate.value(), lhs,
414                                                  rhs);
415   }
416   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
417     if (complexType.getElementType().isa<FloatType>()) {
418       if (comparisonDirection == ComparisonDirection::EQ) {
419         return b->create<complex::EqualOp>(loc, lhs, rhs);
420       }
421       if (comparisonDirection == ComparisonDirection::NE) {
422         return b->create<complex::NotEqualOp>(loc, lhs, rhs);
423       }
424     }
425   }
426   return nullptr;
427 }
428 
429 inline Value mapReducePrecisionOpToStdScalarOp(
430     Location loc, ArrayRef<Type> argTypes, ValueRange args, OpBuilder* builder,
431     int destExponentBits, int destMantissaBits) {
432   using llvm::APInt;
433   mlir::ImplicitLocOpBuilder b(loc, *builder);
434 
435   // Integer and float types for casting and constant generation.
436   auto floatType =
437       argTypes.front().cast<TensorType>().getElementType().cast<FloatType>();
438   int64_t nbits = floatType.getWidth();
439   auto intType = mlir::IntegerType::get(loc.getContext(), floatType.getWidth());
440 
441   Value xAsInt = b.create<arith::BitcastOp>(intType, args[0]);
442 
443   // SignificandWidth includes the implicit extra bit.
444   auto srcMantissaBits = floatType.getFPMantissaWidth() - 1;
445   int srcExponentBits = nbits - 1 - srcMantissaBits;
446 
447   // Clear the sign bit, it does not participate in rounding and we will restore
448   // it later.
449   APInt signBitMask(nbits, 1);
450   signBitMask <<= nbits - 1;
451 
452   APInt expBitsMask(nbits, 1);
453   expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits;
454 
455   if (destMantissaBits < static_cast<int>(srcMantissaBits)) {
456     // Last remaining mantissa bit.
457     APInt lastMantissaBitMask(nbits, 1);
458     lastMantissaBitMask <<= srcMantissaBits - destMantissaBits;
459 
460     // Compute rounding bias for round-to-nearest with ties to even.  This is
461     // equal to a base value of 0111... plus one bit if the last remaining
462     // mantissa bit is 1.
463     APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1;
464 
465     Value mantissaDiff = b.create<arith::ConstantIntOp>(
466         srcMantissaBits - destMantissaBits, intType);
467     Value highestMantissaMaskVal = b.create<arith::ConstantIntOp>(
468         lastMantissaBitMask.getZExtValue(), intType);
469     Value baseRoundingBiasVal = b.create<arith::ConstantIntOp>(
470         baseRoundingBias.getZExtValue(), intType);
471     Value xLastMantissaBit = b.create<arith::ShRUIOp>(
472         b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff);
473     Value xRoundingBias =
474         b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal);
475 
476     // Add rounding bias, and mask out truncated bits.  Note that the case
477     // where adding the rounding bias overflows into the exponent bits is
478     // correct; the non-masked mantissa bits will all be zero, and the
479     // exponent will be incremented by one.
480     APInt truncationMask = ~(lastMantissaBitMask - 1);
481     Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias);
482     xRounded = b.create<arith::AndIOp>(
483         xRounded,
484         b.create<arith::ConstantIntOp>(truncationMask.getZExtValue(), intType)
485             .getResult());
486     xAsInt = xRounded;
487   }
488 
489   if (destExponentBits < srcExponentBits) {
490     // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
491     // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
492     // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
493     // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
494     // exponent (corresponding to 0.0f).
495     //
496     // Thus, the f32 exponent corresponding to the highest non-infinite
497     // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
498     // exponent corresponding to the lowest exponent for a bit size of n is
499     // (2^7-1) - 2^(n-1)-1.
500     //
501     // Note that we have already checked that exponents_bits >= 1.
502     APInt exponentBias(nbits, 1);
503     exponentBias = (exponentBias << (srcExponentBits - 1)) - 1;
504 
505     APInt reducedExponentBias(nbits, 1);
506     reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1;
507 
508     APInt reducedMaxExponent = exponentBias + reducedExponentBias;
509     APInt reducedMinExponent = exponentBias - reducedExponentBias;
510 
511     // Do we overflow or underflow?
512     Value xExponent = b.create<arith::AndIOp>(
513         xAsInt,
514         b.create<arith::ConstantIntOp>(expBitsMask.getZExtValue(), intType)
515             .getResult());
516     Value xOverflows = b.create<arith::CmpIOp>(
517         arith::CmpIPredicate::ugt, xExponent,
518         b.create<arith::ConstantIntOp>(
519              (reducedMaxExponent << srcMantissaBits).getZExtValue(), intType)
520             .getResult());
521     Value xUnderflows = b.create<arith::CmpIOp>(
522         arith::CmpIPredicate::ule, xExponent,
523         b.create<arith::ConstantIntOp>(
524              (reducedMinExponent << srcMantissaBits).getZExtValue(), intType)
525             .getResult());
526 
527     // Compute appropriately-signed values of zero and infinity.
528     Value xSignedZero = b.create<arith::AndIOp>(
529         xAsInt,
530         b.create<arith::ConstantIntOp>(signBitMask.getZExtValue(), intType)
531             .getResult());
532     Value xSignedInf = b.create<arith::OrIOp>(
533         xSignedZero,
534         b.create<arith::ConstantIntOp>(expBitsMask.getZExtValue(), intType)
535             .getResult());
536 
537     // Force to zero or infinity if overflow or underflow.  (Note that this
538     // truncates all denormal values to zero, rather than rounding them.)
539     xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt);
540     xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt);
541   }
542 
543   return b.create<arith::BitcastOp>(floatType, xAsInt);
544 }
545 
546 template <>
547 inline Value mapMhloOpToStdScalarOp<mhlo::CopyOp>(
548     Location /*loc*/, ArrayRef<Type> /*ResultTypes*/,
549     ArrayRef<Type> /*argTypes*/, ValueRange args, OpBuilder* /*b*/) {
550   return args.front();
551 }
552 
553 template <>
554 inline Value mapMhloOpToStdScalarOp<mhlo::ComplexOp>(Location loc,
555                                                      ArrayRef<Type> resultTypes,
556                                                      ArrayRef<Type> argTypes,
557                                                      ValueRange args,
558                                                      OpBuilder* b) {
559   return MapMhloOpToScalarOpImpl<complex::CreateOp>{}(loc, resultTypes,
560                                                       argTypes, args, b);
561 }
562 
563 template <>
564 inline Value mapMhloOpToStdScalarOp<mhlo::RealOp>(Location loc,
565                                                   ArrayRef<Type> resultTypes,
566                                                   ArrayRef<Type> argTypes,
567                                                   ValueRange args,
568                                                   OpBuilder* b) {
569   if (!args[0].getType().isa<ComplexType>()) return args[0];
570   return MapMhloOpToScalarOpImpl<complex::ReOp>{}(loc, resultTypes, argTypes,
571                                                   args, b);
572 }
573 
574 template <>
575 inline Value mapMhloOpToStdScalarOp<mhlo::ImagOp>(Location loc,
576                                                   ArrayRef<Type> resultTypes,
577                                                   ArrayRef<Type> argTypes,
578                                                   ValueRange args,
579                                                   OpBuilder* b) {
580   if (!args[0].getType().isa<ComplexType>())
581     return b->create<arith::ConstantOp>(loc, b->getZeroAttr(args[0].getType()));
582   return MapMhloOpToScalarOpImpl<complex::ImOp>{}(loc, resultTypes, argTypes,
583                                                   args, b);
584 }
585 
586 // 'target_types' is the unconverted type (signed or unsigned if integer),
587 // 'ResultTypes' is the converted type (signless if integer).
588 inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef<Type> targetTypes,
589                                        ArrayRef<Type> resultTypes,
590                                        ArrayRef<Type> argTypes, ValueRange args,
591                                        OpBuilder* b) {
592   assert(targetTypes.size() == 1 && "ConvertOp should return a single result");
593   assert(resultTypes.size() == 1 && "ConvertOp should return a single result");
594   assert(argTypes.size() == 1 && "ConvertOp should take a single argument");
595   assert(args.size() == 1 && "ConvertOp should take a single argument");
596 
597   Type sourceType = getElementTypeOrSelf(argTypes.front());
598   Type targetType = getElementTypeOrSelf(targetTypes.front());
599   Type convertedSourceType = getElementTypeOrSelf(args.front());
600 
601   // A boolean value is considered to be unsigned when converting to
602   // floating-point. Otherwise, it will become `-1`.
603   if (IsUnsignedIntegerType{}(sourceType) &&
604       mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType,
605                                                targetType)) {
606     return b->create<mlir::arith::UIToFPOp>(loc, resultTypes, args, mlir::None);
607   }
608   if (mlir::arith::SIToFPOp::areCastCompatible(sourceType, targetType)) {
609     return b->create<mlir::arith::SIToFPOp>(loc, resultTypes, args, mlir::None);
610   }
611   if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
612     auto src = sourceType.cast<FloatType>();
613     auto res = targetType.cast<FloatType>();
614     if (src.getWidth() > res.getWidth()) {
615       return b->create<mlir::arith::TruncFOp>(loc, resultTypes, args,
616                                               mlir::None);
617     }
618     if (src.getWidth() < res.getWidth()) {
619       return b->create<mlir::arith::ExtFOp>(loc, resultTypes, args, mlir::None);
620     }
621     // There's no direct conversion between different 16 bit floating point
622     // types, so go through 32 bit float.
623     if (sourceType != targetType) {
624       assert(sourceType.isBF16() || targetType.isBF16());
625       Value ext = b->create<arith::ExtFOp>(loc, b->getF32Type(), args);
626       return b->create<arith::TruncFOp>(loc, resultTypes, ext);
627     }
628     // No conversion is needed for identical float types.
629     return args.front();
630   }
631   if (targetType.isInteger(/*width=*/1)) {
632     // When casting to bool, we need to compare whether the value is equal to
633     // zero.
634     if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) {
635       Value zeroIntval = b->create<arith::ConstantOp>(
636           loc, b->getZeroAttr(args.front().getType()));
637       return b->create<mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
638                                             args.front(), zeroIntval);
639     }
640     if (sourceType.isa<FloatType>()) {
641       Value zero = b->create<arith::ConstantOp>(
642           loc, b->getZeroAttr(args.front().getType()));
643       return b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
644                                             args.front(), zero);
645     }
646   }
647   if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
648     auto src = sourceType.cast<IntegerType>();
649     auto res = targetType.cast<IntegerType>();
650     if (src.getWidth() > res.getWidth()) {
651       return b->create<mlir::arith::TruncIOp>(loc, resultTypes, args,
652                                               mlir::None);
653     }
654     if (src.getWidth() < res.getWidth()) {
655       // Special case boolean values, so they get casted to `1` instead of `-1`.
656       if (IsUnsignedIntegerType{}(src)) {
657         return b->create<mlir::arith::ExtUIOp>(loc, resultTypes, args,
658                                                mlir::None);
659       }
660       return b->create<mlir::arith::ExtSIOp>(loc, resultTypes, args,
661                                              mlir::None);
662     }
663     // No conversion is needed for the same width integers
664     return args.front();
665   }
666   if (targetType.isUnsignedInteger() &&
667       mlir::arith::FPToUIOp::areCastCompatible(convertedSourceType,
668                                                targetType)) {
669     return b->create<mlir::arith::FPToUIOp>(loc, resultTypes, args, mlir::None);
670   }
671   if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType,
672                                                targetType)) {
673     return b->create<mlir::arith::FPToSIOp>(loc, resultTypes, args, mlir::None);
674   }
675   if (targetType.isa<ComplexType>()) {
676     Type targetElementType = targetType.cast<ComplexType>().getElementType();
677     assert(!targetElementType.isa<ComplexType>() &&
678            "elements of complex numbers should not be complex");
679     Value targetReal;
680     Value targetImag;
681     if (sourceType.isa<ComplexType>()) {
682       // We are converting from complex type: convert real and imaginary parts
683       // separately.
684       Type sourceElementType = sourceType.cast<ComplexType>().getElementType();
685       assert(!sourceElementType.isa<ComplexType>() &&
686              "elements of complex numbers should not be complex");
687       Value sourceReal =
688           b->create<mlir::complex::ReOp>(loc, sourceElementType, args.front());
689       targetReal =
690           mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType,
691                                     sourceElementType, sourceReal, b);
692       Value sourceImag =
693           b->create<mlir::complex::ImOp>(loc, sourceElementType, args.front());
694       targetImag =
695           mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType,
696                                     sourceElementType, sourceImag, b);
697     } else {
698       // We are converting from real (float, integer, etc.) type, convert the
699       // real part and set the imaginary part to 0.
700       targetReal = mapConvertOpToStdScalarOp(
701           loc, targetElementType, targetElementType, argTypes, args, b);
702       targetImag = b->create<mlir::arith::ConstantOp>(
703           loc, b->getFloatAttr(targetElementType, 0.0));
704     }
705     return b->create<mlir::complex::CreateOp>(loc, targetType, targetReal,
706                                               targetImag);
707   }
708   if (auto sourceComplexType = sourceType.dyn_cast<ComplexType>()) {
709     auto sourceElementType = sourceComplexType.getElementType();
710     // When converting from complex to a non-complex type, we take just the real
711     // part of the complex number.
712     Value sourceReal =
713         b->create<mlir::complex::ReOp>(loc, sourceElementType, args.front());
714     return mapConvertOpToStdScalarOp(loc, targetTypes, resultTypes,
715                                      sourceElementType, sourceReal, b);
716   }
717   return nullptr;
718 }
719 
720 template <>
721 inline Value mapMhloOpToStdScalarOp<mhlo::BitcastConvertOp>(
722     Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type>, ValueRange args,
723     OpBuilder* b) {
724   return b->create<mlir::arith::BitcastOp>(loc, resultTypes, args);
725 }
726 
727 template <>
728 inline Value mapMhloOpToStdScalarOp<mhlo::DotOp>(Location loc,
729                                                  ArrayRef<Type> resultTypes,
730                                                  ArrayRef<Type> argTypes,
731                                                  ValueRange args,
732                                                  OpBuilder* b) {
733   // Dot Op converter from lhlo to affine only accepts float and integer types.
734   const auto& lhs = args[0];
735   const auto& rhs = args[1];
736   const auto& result = args[2];
737   Type elementType = lhs.getType();
738   if (elementType.isa<FloatType>()) {
739     Value floatMul =
740         MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::arith::MulFOp>{}(
741             loc, resultTypes, argTypes, {lhs, rhs}, b);
742     return MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::arith::AddFOp>{}(
743         loc, resultTypes, argTypes, {floatMul, result}, b);
744   }
745   if (elementType.isa<IntegerType>()) {
746     Value intMul =
747         MapMhloOpToScalarOpImpl<IsAnyIntegerType, ::mlir::arith::MulIOp>{}(
748             loc, resultTypes, argTypes, {lhs, rhs}, b);
749     return MapMhloOpToScalarOpImpl<IsAnyIntegerType, ::mlir::arith::AddIOp>{}(
750         loc, resultTypes, argTypes, {intMul, result}, b);
751   }
752   return nullptr;
753 }
754 
755 template <>
756 inline Value mapMhloOpToStdScalarOp<mhlo::IsFiniteOp>(
757     Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/,
758     ValueRange args, OpBuilder* b) {
759   if (args[0].getType().isa<FloatType>()) {
760     auto posInf = APFloat::getInf(
761         args[0].getType().cast<FloatType>().getFloatSemantics());
762     auto constPosInf = b->create<arith::ConstantOp>(
763         loc, b->getFloatAttr(args[0].getType(), posInf));
764     Value absX = b->create<::mlir::math::AbsFOp>(loc, args[0]);
765     return b->create<::mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
766                                             absX, constPosInf);
767   }
768   return nullptr;
769 }
770 
771 /// Implements the conversion of HLO op to scalar op (to use within region of a
772 /// linalg.generic op) for compare-select style operations like min/max.
773 template <typename... Args>
774 struct CompareSelectOpToStdScalarOp {
775   static Value map(Location /*loc*/,
776                    ComparisonDirection /*comparison_direction*/,
777                    ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/,
778                    ValueRange /*args*/, OpBuilder* /*b*/) {
779     return nullptr;
780   }
781 };
782 
783 /// Specialization which allows converting to a comparison operation in standard
784 /// dialect with a given predicate based on the element type of the operand.
785 template <typename SupportedType, typename StdCompareOp, typename Predicate,
786           typename... Args>
787 struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
788                                     Args...> {
789   static Value map(Location loc, ComparisonDirection comparisonDirection,
790                    ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes,
791                    ValueRange args, OpBuilder* b) {
792     Type elementType = getElementTypeOrSelf(argTypes.front());
793     if (elementType.isa<SupportedType>()) {
794       auto predicate = getCmpPredicate<Predicate>(
795           comparisonDirection, !elementType.isUnsignedInteger());
796       assert(predicate.has_value() && "expected valid comparison direction");
797       auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
798                                                   args[0], args[1]);
799       return b->create<::mlir::arith::SelectOp>(loc, cmp, args[0], args[1]);
800     }
801     return CompareSelectOpToStdScalarOp<Args...>::map(
802         loc, comparisonDirection, resultTypes, argTypes, args, b);
803   }
804 };
805 
806 inline Value mhloAlwaysPropagateNaN(Value v, ValueRange args, Location loc,
807                                     OpBuilder* b) {
808   Type elementType = getElementTypeOrSelf(args.front().getType());
809   if (auto floatType = elementType.dyn_cast<FloatType>()) {
810     Value isnan = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
811                                                  args[0], args[1]);
812 
813     auto nanApfloat = APFloat::getQNaN(floatType.getFloatSemantics());
814     Value nan = getConstantOrSplat(b, loc, args[0].getType(),
815                                    b->getFloatAttr(floatType, nanApfloat));
816     v = b->create<mlir::arith::SelectOp>(loc, isnan, nan, v);
817   }
818   return v;
819 }
820 
821 template <>
822 inline Value mapMhloOpToStdScalarOp<mhlo::LogisticOp>(
823     Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> /*argTypes*/,
824     ValueRange args, OpBuilder* b) {
825   auto ty = resultTypes.front().cast<FloatType>();
826   Value one = b->create<arith::ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
827   Value x = args.front();
828   Value negX = b->create<arith::NegFOp>(loc, x);
829   Value expNegX = b->create<::mlir::math::ExpOp>(loc, negX);
830   Value oneAddExpNegX = b->create<arith::AddFOp>(loc, one, expNegX);
831   return b->create<arith::DivFOp>(loc, one, oneAddExpNegX);
832 }
833 
834 template <>
835 inline Value mapMhloOpToStdScalarOp<mhlo::ClampOp>(Location loc,
836                                                    ArrayRef<Type> resultTypes,
837                                                    ArrayRef<Type> argTypes,
838                                                    ValueRange args,
839                                                    OpBuilder* b) {
840   mhlo::ClampOp::Adaptor op(args);
841   // clamp(lb, x, ub) = min(max(lb, x), ub)
842   Value maxLbX = mapMhloOpToStdScalarOp<mhlo::MaxOp>(
843       loc, resultTypes, argTypes, {op.min(), op.operand()}, b);
844   return mapMhloOpToStdScalarOp<mhlo::MinOp>(loc, resultTypes, argTypes,
845                                              {maxLbX, op.max()}, b);
846 }
847 
848 template <typename U, typename S>
849 inline Value makeSafeIntDiv(ImplicitLocOpBuilder& lb, Type originalType,
850                             Value lhs, Value rhs, Value returnedOnZero,
851                             Value returnedOnSignedOverflow) {
852   Type type = lhs.getType();
853   auto elementType = getElementTypeOrSelf(type).cast<IntegerType>();
854   Value zero = lb.create<arith::ConstantOp>(lb.getZeroAttr(type));
855   auto makeConstant = [&](const APInt& i) {
856     return getConstantOrSplat(&lb, lb.getLoc(), type,
857                               lb.getIntegerAttr(elementType, i));
858   };
859   Value one = makeConstant(APInt(elementType.getWidth(), 1));
860   Value rhsIsZero =
861       lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, rhs, zero);
862 
863   // For unsigned just set the divisor to 1 when it would be 0.
864   if (originalType.isUnsignedInteger()) {
865     Value safeRhs = lb.create<arith::SelectOp>(rhsIsZero, one, rhs);
866     Value safeDiv = lb.create<U>(lhs, safeRhs);
867     return lb.create<arith::SelectOp>(rhsIsZero, returnedOnZero, safeDiv);
868   }
869 
870   // For signed also check for INT_MIN / -1.
871   Value smin = makeConstant(APInt::getSignedMinValue(elementType.getWidth()));
872   Value lhsIsSmin =
873       lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, lhs, smin);
874   Value minusOne = makeConstant(APInt::getAllOnesValue(elementType.getWidth()));
875   Value rhsIsMinusOne =
876       lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, rhs, minusOne);
877   Value hasIntMinOverflow = lb.create<arith::AndIOp>(lhsIsSmin, rhsIsMinusOne);
878   Value rhsIsUnsafe = lb.create<arith::OrIOp>(rhsIsZero, hasIntMinOverflow);
879   Value safeRhs = lb.create<arith::SelectOp>(rhsIsUnsafe, one, rhs);
880   Value safeDiv = lb.create<S>(lhs, safeRhs);
881   Value safeSmin = lb.create<arith::SelectOp>(
882       hasIntMinOverflow, returnedOnSignedOverflow, safeDiv);
883   return lb.create<arith::SelectOp>(rhsIsZero, returnedOnZero, safeSmin);
884 }
885 
886 template <>
887 inline Value mapMhloOpToStdScalarOp<mhlo::DivOp>(Location loc,
888                                                  ArrayRef<Type> resultTypes,
889                                                  ArrayRef<Type> argTypes,
890                                                  ValueRange args,
891                                                  OpBuilder* b) {
892   Type originalType = getElementTypeOrSelf(argTypes.front());
893   if (originalType.isa<ComplexType, FloatType>()) {
894     return MapMhloOpToScalarOpImpl<IsFloatType, arith::DivFOp, IsComplexType,
895                                    complex::DivOp>{}(loc, resultTypes, argTypes,
896                                                      args, b);
897   }
898 
899   // Integer division overflow behavior:
900   //
901   // X / 0 == -1
902   // INT_SMIN /s -1 = INT_SMIN
903   ImplicitLocOpBuilder lb(loc, *b);
904   Type type = args.front().getType();
905   auto elementType = getElementTypeOrSelf(type).cast<IntegerType>();
906   auto makeConstant = [&](const APInt& i) {
907     return getConstantOrSplat(&lb, lb.getLoc(), type,
908                               lb.getIntegerAttr(elementType, i));
909   };
910   Value minusOne = makeConstant(APInt::getAllOnesValue(elementType.getWidth()));
911   Value smin = makeConstant(APInt::getSignedMinValue(elementType.getWidth()));
912   return makeSafeIntDiv<arith::DivUIOp, arith::DivSIOp>(
913       lb, originalType, args[0], args[1], /*returnedOnZero=*/minusOne,
914       /*returnedOnSignedOverflow=*/smin);
915 }
916 
917 template <>
918 inline Value mapMhloOpToStdScalarOp<mhlo::RemOp>(Location loc,
919                                                  ArrayRef<Type> resultTypes,
920                                                  ArrayRef<Type> argTypes,
921                                                  ValueRange args,
922                                                  OpBuilder* b) {
923   Type originalType = getElementTypeOrSelf(argTypes.front());
924   if (originalType.isa<ComplexType, FloatType>()) {
925     return MapMhloOpToScalarOpImpl<IsFloatType, arith::RemFOp>{}(
926         loc, resultTypes, argTypes, args, b);
927   }
928 
929   // Integer remainder overflow behavior:
930   //
931   // X % 0 == X
932   // INT_SMIN %s -1 = 0
933   ImplicitLocOpBuilder lb(loc, *b);
934   Type type = args.front().getType();
935   Value zero = lb.create<arith::ConstantOp>(lb.getZeroAttr(type));
936   return makeSafeIntDiv<arith::RemUIOp, arith::RemSIOp>(
937       lb, originalType, args[0], args[1], /*returnedOnZero=*/args[0],
938       /*returnedOnSignedOverflow=*/zero);
939 }
940 
941 template <>
942 inline Value mapMhloOpToStdScalarOp<mhlo::NegOp>(Location loc,
943                                                  ArrayRef<Type> resultTypes,
944                                                  ArrayRef<Type> argTypes,
945                                                  ValueRange args,
946                                                  OpBuilder* b) {
947   Type elementType = getElementTypeOrSelf(args.front().getType());
948   if (elementType.isa<ComplexType, FloatType>()) {
949     return MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::arith::NegFOp,
950                                    IsComplexType, ::mlir::complex::NegOp>{}(
951         loc, resultTypes, argTypes, args, b);
952   }
953   if (elementType.isa<IntegerType>()) {
954     // lmhlo.neg(x, result) -> result = sub(0, x)
955     Value lhs = args[0];
956     Value zeroIntval =
957         b->create<arith::ConstantOp>(loc, b->getZeroAttr(lhs.getType()));
958     return b->create<ScalarIOp<mhlo::SubtractOp>>(loc, zeroIntval, lhs);
959   }
960   return nullptr;
961 }
962 
963 template <>
964 inline Value mapMhloOpToStdScalarOp<mhlo::NotOp>(Location loc,
965                                                  ArrayRef<Type> /*ResultTypes*/,
966                                                  ArrayRef<Type> /*argTypes*/,
967                                                  ValueRange args,
968                                                  OpBuilder* b) {
969   Type elementType = getElementTypeOrSelf(args.front().getType());
970   if (auto integerType = elementType.dyn_cast<IntegerType>()) {
971     // lmhlo.not(x) -> x ^ -1
972     Value allOnes = getConstantOrSplat(
973         b, loc, args[0].getType(),
974         b->getIntegerAttr(integerType,
975                           APInt::getAllOnesValue(integerType.getWidth())));
976     return b->create<::mlir::arith::XOrIOp>(loc, allOnes, args[0]);
977   }
978   return nullptr;
979 }
980 
/// Lowers mhlo.power.
///
/// Float and complex result types map onto math.powf / complex.pow. Integer
/// result types use exponentiation by squaring inside an scf.for with a fixed
/// trip count of 6, followed by fix-ups for negative exponents (see below).
template <>
inline Value mapMhloOpToStdScalarOp<mhlo::PowOp>(Location loc,
                                                 ArrayRef<Type> resultTypes,
                                                 ArrayRef<Type> argTypes,
                                                 ValueRange args,
                                                 OpBuilder* b) {
  mhlo::PowOp::Adaptor adaptor(args);
  auto lb = ImplicitLocOpBuilder(loc, *b);
  // Floating point and complex types map directly onto math.powf /
  // complex.pow.
  auto resultType = resultTypes.front();
  if (resultType.isa<ComplexType, FloatType>()) {
    return MapMhloOpToScalarOpImpl<IsFloatType, math::PowFOp, IsComplexType,
                                   complex::PowOp>{}(loc, resultTypes, argTypes,
                                                     args, b);
  }

  // Exponentiation by squaring:
  // https://en.wikipedia.org/wiki/Exponentiation_by_squaring
  Value negOne =
      lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, -1));
  Value zero = lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, 0));
  Value one = lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, 1));
  Value two = lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, 2));
  Value step = lb.create<arith::ConstantIndexOp>(1);
  Value lowerBound = lb.create<arith::ConstantIndexOp>(0);
  // The loop consumes one exponent bit per iteration for a fixed 6 iterations,
  // so only the low 6 bits of the exponent (values up to 63) reach `accum`.
  // Rationale: for any |base| >= 2, an exponent of 64 (1 << 6) or more
  // overflows even a 64-bit integer.
  // NOTE(review): larger exponents are silently truncated, e.g. 0^64 yields 1
  // rather than 0; confirm this matches the intended semantics.
  Value upperBound = lb.create<arith::ConstantIndexOp>(6);
  auto originalBase = adaptor.lhs();
  auto originalExponent = adaptor.rhs();

  // Loop-carried values: {accum, base, exponent}.
  Value accum =
      lb.create<scf::ForOp>(
            lowerBound, upperBound, step,
            SmallVector<Value>({one, originalBase, originalExponent}),
            [&](OpBuilder& b, Location, Value /*v*/, ValueRange iters) {
              Value accum = iters[0];
              Value base = iters[1];
              Value exponent = iters[2];

              // If the lowest remaining exponent bit is set, fold the current
              // base power into the accumulator.
              Value condition = b.create<arith::CmpIOp>(
                  loc, arith::CmpIPredicate::eq,
                  b.create<::mlir::arith::AndIOp>(loc, exponent, one), one);
              Value multiplied =
                  b.create<::mlir::arith::MulIOp>(loc, accum, base);
              accum = b.create<::mlir::arith::SelectOp>(loc, condition,
                                                        multiplied, accum);
              // Square the base and drop the consumed bit. The shift is
              // logical; negative exponents are handled after the loop.
              base = b.create<::mlir::arith::MulIOp>(loc, base, base);
              exponent = b.create<::mlir::arith::ShRUIOp>(loc, exponent, one);
              b.create<scf::YieldOp>(
                  loc, SmallVector<Value>({accum, base, exponent}));
            })
          .getResult(0);

  Value rhsIsEven = lb.create<arith::CmpIOp>(
      arith::CmpIPredicate::eq, lb.create<arith::RemSIOp>(adaptor.rhs(), two),
      zero);
  Value rhsIsNegative =
      lb.create<arith::CmpIOp>(arith::CmpIPredicate::slt, adaptor.rhs(), zero);
  Value lhsIsOne =
      lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, adaptor.lhs(), one);
  Value lhsIsNegOne =
      lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, adaptor.lhs(), negOne);

  // The accum is correct when the rhs is non-negative. When rhs is
  // negative, we return 0 for integer, with the exception of lhs values of 1
  // and -1 which have integer results for negative exponents. Specifically, the
  // calculation is the following:
  //
  // - Return accum if the rhs is not negative.
  // - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
  // - Return 1 if lhs is 1.
  // - Else return 0.
  Value ifLhsIsOne = lb.create<::mlir::arith::SelectOp>(lhsIsOne, one, zero);
  Value ifLhsIsNegOne = lb.create<::mlir::arith::SelectOp>(
      lhsIsNegOne, lb.create<::mlir::arith::SelectOp>(rhsIsEven, one, negOne),
      ifLhsIsOne);
  return lb.create<::mlir::arith::SelectOp>(rhsIsNegative, ifLhsIsNegOne,
                                            accum);
}
1062 
1063 template <>
1064 inline Value mapMhloOpToStdScalarOp<mhlo::SelectOp>(Location loc,
1065                                                     ArrayRef<Type> resultTypes,
1066                                                     ArrayRef<Type> argTypes,
1067                                                     ValueRange args,
1068                                                     OpBuilder* b) {
1069   return MapMhloOpToScalarOpImpl<::mlir::arith::SelectOp>{}(loc, resultTypes,
1070                                                             argTypes, args, b);
1071 }
1072 
1073 template <>
1074 inline Value mapMhloOpToStdScalarOp<mhlo::SignOp>(Location loc,
1075                                                   ArrayRef<Type> resultTypes,
1076                                                   ArrayRef<Type> /*argTypes*/,
1077                                                   ValueRange args,
1078                                                   OpBuilder* b) {
1079   Type elementType = getElementTypeOrSelf(args.front().getType());
1080   if (auto floatType = elementType.dyn_cast<FloatType>()) {
1081     Value zero =
1082         b->create<arith::ConstantOp>(loc, b->getZeroAttr(args[0].getType()));
1083     Value ne0I1 = b->create<::mlir::arith::CmpFOp>(
1084         loc, arith::CmpFPredicate::ONE, args[0], zero);
1085     Value ne0Float =
1086         b->create<::mlir::arith::UIToFPOp>(loc, zero.getType(), ne0I1);
1087     Value copySign = b->create<::mlir::math::CopySignOp>(loc, resultTypes,
1088                                                          ne0Float, args[0]);
1089     auto isNan = b->create<::mlir::arith::CmpFOp>(
1090         loc, arith::CmpFPredicate::UNO, args[0], args[0]);
1091     return b->create<::mlir::arith::SelectOp>(loc, isNan, args[0], copySign);
1092   }
1093   if (auto integerType = elementType.dyn_cast<IntegerType>()) {
1094     // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
1095     Value zero =
1096         b->create<arith::ConstantOp>(loc, b->getZeroAttr(args[0].getType()));
1097     Value bitwidthMinusOne = getConstantOrSplat(
1098         b, loc, args[0].getType(),
1099         b->getIntegerAttr(integerType, integerType.getWidth() - 1));
1100     Value one = getConstantOrSplat(b, loc, args[0].getType(),
1101                                    b->getIntegerAttr(integerType, 1));
1102     Value cmp = b->create<::mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1103                                                  args[0], zero);
1104     Value ashr =
1105         b->create<::mlir::arith::ShRSIOp>(loc, args[0], bitwidthMinusOne);
1106     Value orOp = b->create<::mlir::arith::OrIOp>(loc, ashr, one);
1107     return b->create<::mlir::arith::SelectOp>(loc, cmp, zero, orOp);
1108   }
1109   if (elementType.isa<ComplexType>()) {
1110     return b->create<::mlir::complex::SignOp>(loc, elementType, args.front());
1111   }
1112   return nullptr;
1113 }
1114 
1115 }  // namespace impl
1116 
1117 struct MhloOpToStdScalarOp {
1118   // Converts mhlo 'op' to linalg and arith ops.
1119   template <typename MhloOpTy>
1120   static Value mapOp(MhloOpTy op, ArrayRef<Type> resultTypes, ValueRange args,
1121                      OpBuilder* b) {
1122     auto argTypes = llvm::to_vector(op->getOperandTypes());
1123     return mapOpWithArgTypes(op, resultTypes, argTypes, args, b);
1124   }
1125 
1126   // Converts mhlo 'op' to linalg and arith ops. The types of 'args' may already
1127   // be converted, 'argTypes' are their original types.
1128   template <typename MhloOpTy>
1129   static Value mapOpWithArgTypes(MhloOpTy op, ArrayRef<Type> resultTypes,
1130                                  ArrayRef<Type> argTypes, ValueRange args,
1131                                  OpBuilder* b) {
1132     static_assert(!std::is_same<MhloOpTy, mhlo::ConvertOp>::value);
1133     return mapOpOfType<MhloOpTy>(op.getLoc(), resultTypes, argTypes, args, b);
1134   }
1135   // Overload for mhlo::ReducePrecisionOp.
1136   static Value mapOpWithArgTypes(mhlo::ReducePrecisionOp op,
1137                                  ArrayRef<Type> /*ResultTypes*/,
1138                                  ArrayRef<Type> argTypes, ValueRange args,
1139                                  OpBuilder* b) {
1140     return impl::mapReducePrecisionOpToStdScalarOp(
1141         op.getLoc(), argTypes, args, b, op.exponent_bits(), op.mantissa_bits());
1142   }
1143   // Overload for mhlo::CompareOp.
1144   static Value mapOpWithArgTypes(mhlo::CompareOp op, ArrayRef<Type> resultTypes,
1145                                  ArrayRef<Type> argTypes, ValueRange args,
1146                                  OpBuilder* b) {
1147     auto comparisonDirection = op.comparison_direction();
1148     return impl::mapCompareOpToStdScalarOp(op.getLoc(), comparisonDirection,
1149                                            resultTypes, argTypes, args, b);
1150   }
1151   // Overload for mhlo::ConvertOp.
1152   static Value mapOpWithArgTypes(mhlo::ConvertOp op, ArrayRef<Type> resultTypes,
1153                                  ArrayRef<Type> argTypes, ValueRange args,
1154                                  OpBuilder* b) {
1155     return impl::mapConvertOpToStdScalarOp(op.getLoc(), op.getType(),
1156                                            resultTypes, argTypes, args, b);
1157   }
1158 
1159   // Converts mhlo 'op' (except mhlo::CompareOp) to linalg and arith ops.
1160   template <typename MhloOpTy>
1161   static Value mapOpOfType(Location loc, ArrayRef<Type> resultTypes,
1162                            ArrayRef<Type> argTypes, ValueRange args,
1163                            OpBuilder* b) {
1164     static_assert(!std::is_same<MhloOpTy, mhlo::CompareOp>::value, "invalid");
1165     if (std::is_same<MhloOpTy, mhlo::ConvertOp>::value) {
1166       // Note: this assumes that the caller is passing result/arg types with
1167       // appropriate signedness.
1168       return impl::mapConvertOpToStdScalarOp(loc, resultTypes, resultTypes,
1169                                              argTypes, args, b);
1170     }
1171     return impl::mapMhloOpToStdScalarOp<MhloOpTy>(loc, resultTypes, argTypes,
1172                                                   args, b);
1173   }
1174 };
1175 
1176 }  // namespace mhlo
1177 }  // namespace mlir
1178 
1179 #endif  // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H
1180