1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H 17 #define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H 18 19 #include "llvm/ADT/ArrayRef.h" 20 #include "llvm/ADT/StringRef.h" 21 #include "llvm/ADT/StringSwitch.h" 22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 24 #include "mlir/Dialect/Complex/IR/Complex.h" 25 #include "mlir/Dialect/Math/IR/Math.h" 26 #include "mlir/Dialect/SCF/IR/SCF.h" 27 #include "mlir/Dialect/Vector/IR/VectorOps.h" 28 #include "mlir/IR/BuiltinTypes.h" 29 #include "mlir/IR/ImplicitLocOpBuilder.h" 30 #include "mlir/IR/TypeUtilities.h" 31 32 namespace mlir { 33 namespace mhlo { 34 namespace impl { 35 36 // A struct to map MhloBinaryOpTy type to the corresponding floating-point and 37 // integer scalar operation types. 38 template <typename MhloBinaryOpTy> 39 struct MhloToScalarOp { 40 using FOp = void; 41 using IOp = void; 42 using UOp = void; 43 using COp = void; 44 }; 45 46 template <> 47 struct MhloToScalarOp<mhlo::AddOp> { 48 using FOp = ::mlir::arith::AddFOp; 49 using IOp = ::mlir::arith::AddIOp; 50 using UOp = ::mlir::arith::AddIOp; 51 using COp = ::mlir::complex::AddOp; 52 }; 53 template <> 54 struct MhloToScalarOp<mhlo::AndOp> { 55 using IOp = ::mlir::arith::AndIOp; 56 using UOp = ::mlir::arith::AndIOp; 57 }; 58 template <> 59 struct MhloToScalarOp<mhlo::CompareOp> { 60 using FOp = ::mlir::arith::CmpFOp; 61 using IOp = ::mlir::arith::CmpIOp; 62 using UOp = ::mlir::arith::CmpIOp; 63 }; 64 template <> 65 struct MhloToScalarOp<mhlo::CeilOp> { 66 using FOp = ::mlir::math::CeilOp; 67 }; 68 template <> 69 struct MhloToScalarOp<mhlo::ClzOp> { 70 using IOp = ::mlir::math::CountLeadingZerosOp; 71 using UOp = ::mlir::math::CountLeadingZerosOp; 72 }; 73 template <> 74 struct MhloToScalarOp<mhlo::CosineOp> { 75 using FOp = ::mlir::math::CosOp; 76 using COp = ::mlir::complex::CosOp; 77 }; 78 template <> 79 struct MhloToScalarOp<mhlo::ExpOp> { 80 using FOp = ::mlir::math::ExpOp; 81 using COp = ::mlir::complex::ExpOp; 82 }; 83 template <> 84 struct MhloToScalarOp<mhlo::Expm1Op> { 85 using FOp = ::mlir::math::ExpM1Op; 86 using COp = ::mlir::complex::Expm1Op; 87 }; 88 template <> 89 struct MhloToScalarOp<mhlo::FloorOp> { 90 using FOp = ::mlir::math::FloorOp; 91 }; 92 template <> 93 struct MhloToScalarOp<mhlo::MaxOp> { 94 using FOp = ::mlir::arith::MaxFOp; 95 using IOp = ::mlir::arith::MaxSIOp; 96 using UOp = ::mlir::arith::MaxUIOp; 97 }; 98 template <> 99 struct MhloToScalarOp<mhlo::MinOp> { 100 using FOp = ::mlir::arith::MinFOp; 101 using IOp = ::mlir::arith::MinSIOp; 102 using UOp = ::mlir::arith::MinUIOp; 103 }; 104 template <> 105 struct MhloToScalarOp<mhlo::LogOp> { 106 using FOp = ::mlir::math::LogOp; 107 using COp = ::mlir::complex::LogOp; 108 }; 109 template <> 110 struct MhloToScalarOp<mhlo::Log1pOp> { 111 using FOp = ::mlir::math::Log1pOp; 112 using COp = ::mlir::complex::Log1pOp; 113 }; 114 template <> 115 struct MhloToScalarOp<mhlo::MulOp> { 116 using FOp = ::mlir::arith::MulFOp; 117 using IOp = ::mlir::arith::MulIOp; 118 using UOp = ::mlir::arith::MulIOp; 119 using COp = ::mlir::complex::MulOp; 120 }; 121 template <> 122 struct MhloToScalarOp<mhlo::OrOp> { 123 using IOp = ::mlir::arith::OrIOp; 124 using UOp = ::mlir::arith::OrIOp; 125 }; 126 template <> 127 struct MhloToScalarOp<mhlo::PopulationCountOp> { 128 using IOp = ::mlir::math::CtPopOp; 129 using UOp = ::mlir::math::CtPopOp; 130 }; 131 template <> 132 struct MhloToScalarOp<mhlo::RsqrtOp> { 133 using FOp = ::mlir::math::RsqrtOp; 134 using COp = ::mlir::complex::RsqrtOp; 135 }; 136 template <> 137 struct MhloToScalarOp<mhlo::RoundOp> { 138 using FOp = ::mlir::math::RoundOp; 139 }; 140 template <> 141 struct MhloToScalarOp<mhlo::SubtractOp> { 142 using FOp = ::mlir::arith::SubFOp; 143 using IOp = ::mlir::arith::SubIOp; 144 using UOp = ::mlir::arith::SubIOp; 145 using COp = ::mlir::complex::SubOp; 146 }; 147 template <> 148 struct MhloToScalarOp<mhlo::SqrtOp> { 149 using FOp = ::mlir::math::SqrtOp; 150 using COp = ::mlir::complex::SqrtOp; 151 }; 152 template <> 153 struct MhloToScalarOp<mhlo::SineOp> { 154 using FOp = ::mlir::math::SinOp; 155 using COp = ::mlir::complex::SinOp; 156 }; 157 template <> 158 struct MhloToScalarOp<mhlo::ShiftLeftOp> { 159 using IOp = ::mlir::arith::ShLIOp; 160 using UOp = ::mlir::arith::ShLIOp; 161 }; 162 template <> 163 struct MhloToScalarOp<mhlo::ShiftRightArithmeticOp> { 164 using IOp = ::mlir::arith::ShRSIOp; 165 using UOp = ::mlir::arith::ShRSIOp; 166 }; 167 template <> 168 struct MhloToScalarOp<mhlo::ShiftRightLogicalOp> { 169 using IOp = ::mlir::arith::ShRUIOp; 170 using UOp = ::mlir::arith::ShRUIOp; 171 }; 172 template <> 173 struct MhloToScalarOp<mhlo::Atan2Op> { 174 using FOp = ::mlir::math::Atan2Op; 175 using COp = ::mlir::complex::Atan2Op; 176 }; 177 template <> 178 struct MhloToScalarOp<mhlo::TanhOp> { 179 using FOp = ::mlir::math::TanhOp; 180 using COp = ::mlir::complex::TanhOp; 181 }; 182 template <> 183 struct MhloToScalarOp<mhlo::XorOp> { 184 using IOp = ::mlir::arith::XOrIOp; 185 using UOp = ::mlir::arith::XOrIOp; 186 }; 187 188 // Alias for the map from MHLO binary op type to STD floating-point op type. 189 template <typename MhloOp> 190 using ScalarFOp = typename MhloToScalarOp<MhloOp>::FOp; 191 // Alias for the map from MHLO binary op type to STD signed integer op type. 192 template <typename MhloOp> 193 using ScalarIOp = typename MhloToScalarOp<MhloOp>::IOp; 194 // Alias for the map from MHLO binary op type to STD unsigned integer op type. 195 template <typename MhloOp> 196 using ScalarUOp = typename MhloToScalarOp<MhloOp>::UOp; 197 // Alias for the map from MHLO binary op type to STD complex op type. 198 template <typename MhloOp> 199 using ScalarCOp = typename MhloToScalarOp<MhloOp>::COp; 200 201 template <typename... Args> 202 struct MapMhloOpToScalarOpImpl { 203 Value operator()(Location /*loc*/, ArrayRef<Type> /*ResultTypes*/, 204 ArrayRef<Type> /*argTypes*/, ValueRange /*args*/, 205 OpBuilder* /*b*/) { 206 return nullptr; 207 } 208 }; 209 210 template <typename StdScalarOp> 211 struct MapMhloOpToScalarOpImpl<StdScalarOp> { 212 Value operator()(Location loc, ArrayRef<Type> resultTypes, 213 ArrayRef<Type> /*argTypes*/, ValueRange args, OpBuilder* b) { 214 return b->template create<StdScalarOp>(loc, resultTypes, args, mlir::None); 215 } 216 }; 217 218 template <typename SupportedType, typename StdScalarOp, typename... Args> 219 struct MapMhloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> { 220 Value operator()(Location loc, ArrayRef<Type> resultTypes, 221 ArrayRef<Type> argTypes, ValueRange args, OpBuilder* b) { 222 Type elementType = getElementTypeOrSelf(argTypes.front()); 223 if (SupportedType{}(elementType)) { 224 return b->template create<StdScalarOp>(loc, resultTypes, args, 225 mlir::None); 226 } 227 return MapMhloOpToScalarOpImpl<Args...>{}(loc, resultTypes, argTypes, args, 228 b); 229 } 230 }; 231 232 template <typename SupportedType, typename... Args> 233 struct MapMhloOpToScalarOpImpl<SupportedType, void, Args...> { 234 Value operator()(Location loc, ArrayRef<Type> resultTypes, 235 ArrayRef<Type> argTypes, ValueRange args, OpBuilder* b) { 236 return MapMhloOpToScalarOpImpl<Args...>{}(loc, resultTypes, argTypes, args, 237 b); 238 } 239 }; 240 241 struct IsAnyIntegerType { 242 bool operator()(Type t) { return t.isa<IntegerType>(); } 243 }; 244 245 struct IsSignedIntegerType { 246 bool operator()(Type t) { 247 // Pretend that signless is signed. This will change eventually. 248 return t.isa<IntegerType>() && !t.isUnsignedInteger() && 249 !t.isSignlessInteger(1); 250 } 251 }; 252 253 struct IsUnsignedIntegerType { 254 bool operator()(Type t) { 255 return t.isUnsignedInteger() || t.isSignlessInteger(1); 256 } 257 }; 258 259 struct IsFloatType { 260 bool operator()(Type t) { return t.isa<FloatType>(); } 261 }; 262 263 struct IsComplexType { 264 bool operator()(Type t) { return t.isa<ComplexType>(); } 265 }; 266 267 template <template <typename T> class MapTy, typename OpTy, 268 typename PredTy = llvm::is_detected<MapTy, OpTy>> 269 struct MapableIf { 270 using type = void; 271 }; 272 template <template <typename T> class MapTy, typename OpTy> 273 struct MapableIf<MapTy, OpTy, std::true_type> { 274 using type = MapTy<OpTy>; 275 }; 276 277 // Inserts the computation that corresponds to the body of the loop for lowered 278 // MHLO unary/binary op. Returns the value for the result. 279 template <typename MhloOpTy> 280 inline Value mapMhloOpToStdScalarOp(Location loc, ArrayRef<Type> resultTypes, 281 ArrayRef<Type> argTypes, ValueRange args, 282 OpBuilder* b) { 283 using ScalarIOpOrVoid = typename MapableIf<ScalarIOp, MhloOpTy>::type; 284 using ScalarUOpOrVoid = typename MapableIf<ScalarUOp, MhloOpTy>::type; 285 using ScalarFOpOrVoid = typename MapableIf<ScalarFOp, MhloOpTy>::type; 286 using ScalarCOpOrVoid = typename MapableIf<ScalarCOp, MhloOpTy>::type; 287 return MapMhloOpToScalarOpImpl<IsSignedIntegerType, ScalarIOpOrVoid, 288 IsUnsignedIntegerType, ScalarUOpOrVoid, 289 IsFloatType, ScalarFOpOrVoid, IsComplexType, 290 ScalarCOpOrVoid>{}(loc, resultTypes, argTypes, 291 args, b); 292 } 293 294 template <> 295 inline Value mapMhloOpToStdScalarOp<mhlo::AbsOp>(Location loc, 296 ArrayRef<Type> resultTypes, 297 ArrayRef<Type> argTypes, 298 ValueRange args, 299 OpBuilder* b) { 300 Type elementType = getElementTypeOrSelf(argTypes.front()); 301 if (elementType.isa<FloatType>()) { 302 return MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::math::AbsFOp>{}( 303 loc, resultTypes, argTypes, args, b); 304 } 305 if (elementType.isa<ComplexType>()) { 306 return MapMhloOpToScalarOpImpl<IsComplexType, ::mlir::complex::AbsOp>{}( 307 loc, resultTypes, argTypes, args, b); 308 } 309 if (elementType.isSignlessInteger() || elementType.isSignedInteger()) { 310 // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) 311 Value lhs = args[0]; 312 Value zeroIntval = 313 b->create<arith::ConstantOp>(loc, b->getZeroAttr(lhs.getType())); 314 auto lhsGtZero = b->create<ScalarIOp<CompareOp>>( 315 loc, arith::CmpIPredicate::sge, lhs, zeroIntval); 316 auto negVal = b->create<ScalarIOp<mhlo::SubtractOp>>(loc, zeroIntval, lhs); 317 return b->create<::mlir::arith::SelectOp>(loc, lhsGtZero, lhs, negVal); 318 } 319 return nullptr; 320 } 321 322 // Return a constant for v of type t, splat if t is a vector type. 323 inline Value getConstantOrSplat(OpBuilder* b, Location loc, Type t, 324 Attribute v) { 325 if (VectorType vecType = t.dyn_cast<VectorType>()) { 326 v = SplatElementsAttr::get(vecType, v); 327 } 328 return b->create<arith::ConstantOp>(loc, t, v); 329 } 330 331 template <> 332 inline Value mapMhloOpToStdScalarOp<mhlo::CbrtOp>(Location loc, 333 ArrayRef<Type> resultTypes, 334 ArrayRef<Type> argTypes, 335 ValueRange args, 336 OpBuilder* b) { 337 mhlo::CbrtOp::Adaptor adaptor(args); 338 Type elementType = getElementTypeOrSelf(argTypes.front()); 339 if (auto floatType = elementType.dyn_cast<FloatType>()) { 340 // Convert cbrt(x) to copysign(cbrt(abs(x), 1.0 / 3.0), x). 341 // This is to allow cbrt using pow while still handling negative numbers. It 342 // should match most cbrt intrinsics. 343 Value abs = b->create<mlir::math::AbsFOp>(loc, adaptor.operand()); 344 Value third = b->create<arith::ConstantOp>( 345 loc, b->getFloatAttr(floatType, 1.0 / 3.0)); 346 Value pow = b->create<mlir::math::PowFOp>(loc, resultTypes[0], abs, third); 347 return b->create<mlir::math::CopySignOp>(loc, floatType, pow, 348 adaptor.operand()); 349 } 350 return nullptr; 351 } 352 353 template <typename PredicateType> 354 inline Optional<PredicateType> getCmpPredicate(mhlo::ComparisonDirection, 355 bool) { 356 return llvm::None; 357 } 358 359 template <> 360 inline Optional<arith::CmpFPredicate> getCmpPredicate<arith::CmpFPredicate>( 361 mhlo::ComparisonDirection comparisonDirection, bool isSigned) { 362 assert(isSigned && "cannot have an unsigned float!"); 363 return llvm::StringSwitch<Optional<arith::CmpFPredicate>>( 364 stringifyComparisonDirection(comparisonDirection)) 365 .Case("EQ", arith::CmpFPredicate::OEQ) 366 .Case("NE", arith::CmpFPredicate::UNE) 367 .Case("GE", arith::CmpFPredicate::OGE) 368 .Case("GT", arith::CmpFPredicate::OGT) 369 .Case("LE", arith::CmpFPredicate::OLE) 370 .Case("LT", arith::CmpFPredicate::OLT) 371 .Default(llvm::None); 372 } 373 374 template <> 375 inline Optional<arith::CmpIPredicate> getCmpPredicate<arith::CmpIPredicate>( 376 mhlo::ComparisonDirection comparisonDirection, bool isSigned) { 377 return llvm::StringSwitch<Optional<arith::CmpIPredicate>>( 378 stringifyComparisonDirection(comparisonDirection)) 379 .Case("EQ", arith::CmpIPredicate::eq) 380 .Case("NE", arith::CmpIPredicate::ne) 381 .Case("GE", 382 isSigned ? arith::CmpIPredicate::sge : arith::CmpIPredicate::uge) 383 .Case("GT", 384 isSigned ? arith::CmpIPredicate::sgt : arith::CmpIPredicate::ugt) 385 .Case("LE", 386 isSigned ? arith::CmpIPredicate::sle : arith::CmpIPredicate::ule) 387 .Case("LT", 388 isSigned ? arith::CmpIPredicate::slt : arith::CmpIPredicate::ult) 389 .Default(llvm::None); 390 } 391 392 inline Value mapCompareOpToStdScalarOp(Location loc, 393 ComparisonDirection comparisonDirection, 394 ArrayRef<Type> /*ResultTypes*/, 395 ArrayRef<Type> argTypes, ValueRange args, 396 OpBuilder* b) { 397 const auto& lhs = args[0]; 398 const auto& rhs = args[1]; 399 Type elementType = getElementTypeOrSelf(argTypes.front()); 400 if (elementType.isa<IntegerType>()) { 401 bool isUnsigned = IsUnsignedIntegerType{}(elementType); 402 Optional<arith::CmpIPredicate> predicate = 403 getCmpPredicate<arith::CmpIPredicate>(comparisonDirection, !isUnsigned); 404 assert(predicate.has_value() && "expected valid comparison direction"); 405 return b->create<ScalarIOp<mhlo::CompareOp>>(loc, predicate.value(), lhs, 406 rhs); 407 } 408 if (elementType.isa<FloatType>()) { 409 Optional<arith::CmpFPredicate> predicate = 410 getCmpPredicate<arith::CmpFPredicate>(comparisonDirection, 411 /*is_signed=*/true); 412 assert(predicate.has_value() && "expected valid comparison direction"); 413 return b->create<ScalarFOp<mhlo::CompareOp>>(loc, predicate.value(), lhs, 414 rhs); 415 } 416 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 417 if (complexType.getElementType().isa<FloatType>()) { 418 if (comparisonDirection == ComparisonDirection::EQ) { 419 return b->create<complex::EqualOp>(loc, lhs, rhs); 420 } 421 if (comparisonDirection == ComparisonDirection::NE) { 422 return b->create<complex::NotEqualOp>(loc, lhs, rhs); 423 } 424 } 425 } 426 return nullptr; 427 } 428 429 inline Value mapReducePrecisionOpToStdScalarOp( 430 Location loc, ArrayRef<Type> argTypes, ValueRange args, OpBuilder* builder, 431 int destExponentBits, int destMantissaBits) { 432 using llvm::APInt; 433 mlir::ImplicitLocOpBuilder b(loc, *builder); 434 435 // Integer and float types for casting and constant generation. 436 auto floatType = 437 argTypes.front().cast<TensorType>().getElementType().cast<FloatType>(); 438 int64_t nbits = floatType.getWidth(); 439 auto intType = mlir::IntegerType::get(loc.getContext(), floatType.getWidth()); 440 441 Value xAsInt = b.create<arith::BitcastOp>(intType, args[0]); 442 443 // SignificandWidth includes the implicit extra bit. 444 auto srcMantissaBits = floatType.getFPMantissaWidth() - 1; 445 int srcExponentBits = nbits - 1 - srcMantissaBits; 446 447 // Clear the sign bit, it does not participate in rounding and we will restore 448 // it later. 449 APInt signBitMask(nbits, 1); 450 signBitMask <<= nbits - 1; 451 452 APInt expBitsMask(nbits, 1); 453 expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits; 454 455 if (destMantissaBits < static_cast<int>(srcMantissaBits)) { 456 // Last remaining mantissa bit. 457 APInt lastMantissaBitMask(nbits, 1); 458 lastMantissaBitMask <<= srcMantissaBits - destMantissaBits; 459 460 // Compute rounding bias for round-to-nearest with ties to even. This is 461 // equal to a base value of 0111... plus one bit if the last remaining 462 // mantissa bit is 1. 463 APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1; 464 465 Value mantissaDiff = b.create<arith::ConstantIntOp>( 466 srcMantissaBits - destMantissaBits, intType); 467 Value highestMantissaMaskVal = b.create<arith::ConstantIntOp>( 468 lastMantissaBitMask.getZExtValue(), intType); 469 Value baseRoundingBiasVal = b.create<arith::ConstantIntOp>( 470 baseRoundingBias.getZExtValue(), intType); 471 Value xLastMantissaBit = b.create<arith::ShRUIOp>( 472 b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff); 473 Value xRoundingBias = 474 b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal); 475 476 // Add rounding bias, and mask out truncated bits. Note that the case 477 // where adding the rounding bias overflows into the exponent bits is 478 // correct; the non-masked mantissa bits will all be zero, and the 479 // exponent will be incremented by one. 480 APInt truncationMask = ~(lastMantissaBitMask - 1); 481 Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias); 482 xRounded = b.create<arith::AndIOp>( 483 xRounded, 484 b.create<arith::ConstantIntOp>(truncationMask.getZExtValue(), intType) 485 .getResult()); 486 xAsInt = xRounded; 487 } 488 489 if (destExponentBits < srcExponentBits) { 490 // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- 491 // significant bit -- is equal to 1.0f for all exponent sizes. Adding 492 // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- 493 // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' 494 // exponent (corresponding to 0.0f). 495 // 496 // Thus, the f32 exponent corresponding to the highest non-infinite 497 // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 498 // exponent corresponding to the lowest exponent for a bit size of n is 499 // (2^7-1) - 2^(n-1)-1. 500 // 501 // Note that we have already checked that exponents_bits >= 1. 502 APInt exponentBias(nbits, 1); 503 exponentBias = (exponentBias << (srcExponentBits - 1)) - 1; 504 505 APInt reducedExponentBias(nbits, 1); 506 reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1; 507 508 APInt reducedMaxExponent = exponentBias + reducedExponentBias; 509 APInt reducedMinExponent = exponentBias - reducedExponentBias; 510 511 // Do we overflow or underflow? 512 Value xExponent = b.create<arith::AndIOp>( 513 xAsInt, 514 b.create<arith::ConstantIntOp>(expBitsMask.getZExtValue(), intType) 515 .getResult()); 516 Value xOverflows = b.create<arith::CmpIOp>( 517 arith::CmpIPredicate::ugt, xExponent, 518 b.create<arith::ConstantIntOp>( 519 (reducedMaxExponent << srcMantissaBits).getZExtValue(), intType) 520 .getResult()); 521 Value xUnderflows = b.create<arith::CmpIOp>( 522 arith::CmpIPredicate::ule, xExponent, 523 b.create<arith::ConstantIntOp>( 524 (reducedMinExponent << srcMantissaBits).getZExtValue(), intType) 525 .getResult()); 526 527 // Compute appropriately-signed values of zero and infinity. 528 Value xSignedZero = b.create<arith::AndIOp>( 529 xAsInt, 530 b.create<arith::ConstantIntOp>(signBitMask.getZExtValue(), intType) 531 .getResult()); 532 Value xSignedInf = b.create<arith::OrIOp>( 533 xSignedZero, 534 b.create<arith::ConstantIntOp>(expBitsMask.getZExtValue(), intType) 535 .getResult()); 536 537 // Force to zero or infinity if overflow or underflow. (Note that this 538 // truncates all denormal values to zero, rather than rounding them.) 539 xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt); 540 xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt); 541 } 542 543 return b.create<arith::BitcastOp>(floatType, xAsInt); 544 } 545 546 template <> 547 inline Value mapMhloOpToStdScalarOp<mhlo::CopyOp>( 548 Location /*loc*/, ArrayRef<Type> /*ResultTypes*/, 549 ArrayRef<Type> /*argTypes*/, ValueRange args, OpBuilder* /*b*/) { 550 return args.front(); 551 } 552 553 template <> 554 inline Value mapMhloOpToStdScalarOp<mhlo::ComplexOp>(Location loc, 555 ArrayRef<Type> resultTypes, 556 ArrayRef<Type> argTypes, 557 ValueRange args, 558 OpBuilder* b) { 559 return MapMhloOpToScalarOpImpl<complex::CreateOp>{}(loc, resultTypes, 560 argTypes, args, b); 561 } 562 563 template <> 564 inline Value mapMhloOpToStdScalarOp<mhlo::RealOp>(Location loc, 565 ArrayRef<Type> resultTypes, 566 ArrayRef<Type> argTypes, 567 ValueRange args, 568 OpBuilder* b) { 569 if (!args[0].getType().isa<ComplexType>()) return args[0]; 570 return MapMhloOpToScalarOpImpl<complex::ReOp>{}(loc, resultTypes, argTypes, 571 args, b); 572 } 573 574 template <> 575 inline Value mapMhloOpToStdScalarOp<mhlo::ImagOp>(Location loc, 576 ArrayRef<Type> resultTypes, 577 ArrayRef<Type> argTypes, 578 ValueRange args, 579 OpBuilder* b) { 580 if (!args[0].getType().isa<ComplexType>()) 581 return b->create<arith::ConstantOp>(loc, b->getZeroAttr(args[0].getType())); 582 return MapMhloOpToScalarOpImpl<complex::ImOp>{}(loc, resultTypes, argTypes, 583 args, b); 584 } 585 586 // 'target_types' is the unconverted type (signed or unsigned if integer), 587 // 'ResultTypes' is the converted type (signless if integer). 588 inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef<Type> targetTypes, 589 ArrayRef<Type> resultTypes, 590 ArrayRef<Type> argTypes, ValueRange args, 591 OpBuilder* b) { 592 assert(targetTypes.size() == 1 && "ConvertOp should return a single result"); 593 assert(resultTypes.size() == 1 && "ConvertOp should return a single result"); 594 assert(argTypes.size() == 1 && "ConvertOp should take a single argument"); 595 assert(args.size() == 1 && "ConvertOp should take a single argument"); 596 597 Type sourceType = getElementTypeOrSelf(argTypes.front()); 598 Type targetType = getElementTypeOrSelf(targetTypes.front()); 599 Type convertedSourceType = getElementTypeOrSelf(args.front()); 600 601 // A boolean value is considered to be unsigned when converting to 602 // floating-point. Otherwise, it will become `-1`. 603 if (IsUnsignedIntegerType{}(sourceType) && 604 mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType, 605 targetType)) { 606 return b->create<mlir::arith::UIToFPOp>(loc, resultTypes, args, mlir::None); 607 } 608 if (mlir::arith::SIToFPOp::areCastCompatible(sourceType, targetType)) { 609 return b->create<mlir::arith::SIToFPOp>(loc, resultTypes, args, mlir::None); 610 } 611 if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) { 612 auto src = sourceType.cast<FloatType>(); 613 auto res = targetType.cast<FloatType>(); 614 if (src.getWidth() > res.getWidth()) { 615 return b->create<mlir::arith::TruncFOp>(loc, resultTypes, args, 616 mlir::None); 617 } 618 if (src.getWidth() < res.getWidth()) { 619 return b->create<mlir::arith::ExtFOp>(loc, resultTypes, args, mlir::None); 620 } 621 // There's no direct conversion between different 16 bit floating point 622 // types, so go through 32 bit float. 623 if (sourceType != targetType) { 624 assert(sourceType.isBF16() || targetType.isBF16()); 625 Value ext = b->create<arith::ExtFOp>(loc, b->getF32Type(), args); 626 return b->create<arith::TruncFOp>(loc, resultTypes, ext); 627 } 628 // No conversion is needed for identical float types. 629 return args.front(); 630 } 631 if (targetType.isInteger(/*width=*/1)) { 632 // When casting to bool, we need to compare whether the value is equal to 633 // zero. 634 if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) { 635 Value zeroIntval = b->create<arith::ConstantOp>( 636 loc, b->getZeroAttr(args.front().getType())); 637 return b->create<mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::ne, 638 args.front(), zeroIntval); 639 } 640 if (sourceType.isa<FloatType>()) { 641 Value zero = b->create<arith::ConstantOp>( 642 loc, b->getZeroAttr(args.front().getType())); 643 return b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, 644 args.front(), zero); 645 } 646 } 647 if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) { 648 auto src = sourceType.cast<IntegerType>(); 649 auto res = targetType.cast<IntegerType>(); 650 if (src.getWidth() > res.getWidth()) { 651 return b->create<mlir::arith::TruncIOp>(loc, resultTypes, args, 652 mlir::None); 653 } 654 if (src.getWidth() < res.getWidth()) { 655 // Special case boolean values, so they get casted to `1` instead of `-1`. 656 if (IsUnsignedIntegerType{}(src)) { 657 return b->create<mlir::arith::ExtUIOp>(loc, resultTypes, args, 658 mlir::None); 659 } 660 return b->create<mlir::arith::ExtSIOp>(loc, resultTypes, args, 661 mlir::None); 662 } 663 // No conversion is needed for the same width integers 664 return args.front(); 665 } 666 if (targetType.isUnsignedInteger() && 667 mlir::arith::FPToUIOp::areCastCompatible(convertedSourceType, 668 targetType)) { 669 return b->create<mlir::arith::FPToUIOp>(loc, resultTypes, args, mlir::None); 670 } 671 if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType, 672 targetType)) { 673 return b->create<mlir::arith::FPToSIOp>(loc, resultTypes, args, mlir::None); 674 } 675 if (targetType.isa<ComplexType>()) { 676 Type targetElementType = targetType.cast<ComplexType>().getElementType(); 677 assert(!targetElementType.isa<ComplexType>() && 678 "elements of complex numbers should not be complex"); 679 Value targetReal; 680 Value targetImag; 681 if (sourceType.isa<ComplexType>()) { 682 // We are converting from complex type: convert real and imaginary parts 683 // separately. 684 Type sourceElementType = sourceType.cast<ComplexType>().getElementType(); 685 assert(!sourceElementType.isa<ComplexType>() && 686 "elements of complex numbers should not be complex"); 687 Value sourceReal = 688 b->create<mlir::complex::ReOp>(loc, sourceElementType, args.front()); 689 targetReal = 690 mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, 691 sourceElementType, sourceReal, b); 692 Value sourceImag = 693 b->create<mlir::complex::ImOp>(loc, sourceElementType, args.front()); 694 targetImag = 695 mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, 696 sourceElementType, sourceImag, b); 697 } else { 698 // We are converting from real (float, integer, etc.) type, convert the 699 // real part and set the imaginary part to 0. 700 targetReal = mapConvertOpToStdScalarOp( 701 loc, targetElementType, targetElementType, argTypes, args, b); 702 targetImag = b->create<mlir::arith::ConstantOp>( 703 loc, b->getFloatAttr(targetElementType, 0.0)); 704 } 705 return b->create<mlir::complex::CreateOp>(loc, targetType, targetReal, 706 targetImag); 707 } 708 if (auto sourceComplexType = sourceType.dyn_cast<ComplexType>()) { 709 auto sourceElementType = sourceComplexType.getElementType(); 710 // When converting from complex to a non-complex type, we take just the real 711 // part of the complex number. 712 Value sourceReal = 713 b->create<mlir::complex::ReOp>(loc, sourceElementType, args.front()); 714 return mapConvertOpToStdScalarOp(loc, targetTypes, resultTypes, 715 sourceElementType, sourceReal, b); 716 } 717 return nullptr; 718 } 719 720 template <> 721 inline Value mapMhloOpToStdScalarOp<mhlo::BitcastConvertOp>( 722 Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type>, ValueRange args, 723 OpBuilder* b) { 724 return b->create<mlir::arith::BitcastOp>(loc, resultTypes, args); 725 } 726 727 template <> 728 inline Value mapMhloOpToStdScalarOp<mhlo::DotOp>(Location loc, 729 ArrayRef<Type> resultTypes, 730 ArrayRef<Type> argTypes, 731 ValueRange args, 732 OpBuilder* b) { 733 // Dot Op converter from lhlo to affine only accepts float and integer types. 734 const auto& lhs = args[0]; 735 const auto& rhs = args[1]; 736 const auto& result = args[2]; 737 Type elementType = lhs.getType(); 738 if (elementType.isa<FloatType>()) { 739 Value floatMul = 740 MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::arith::MulFOp>{}( 741 loc, resultTypes, argTypes, {lhs, rhs}, b); 742 return MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::arith::AddFOp>{}( 743 loc, resultTypes, argTypes, {floatMul, result}, b); 744 } 745 if (elementType.isa<IntegerType>()) { 746 Value intMul = 747 MapMhloOpToScalarOpImpl<IsAnyIntegerType, ::mlir::arith::MulIOp>{}( 748 loc, resultTypes, argTypes, {lhs, rhs}, b); 749 return MapMhloOpToScalarOpImpl<IsAnyIntegerType, ::mlir::arith::AddIOp>{}( 750 loc, resultTypes, argTypes, {intMul, result}, b); 751 } 752 return nullptr; 753 } 754 755 template <> 756 inline Value mapMhloOpToStdScalarOp<mhlo::IsFiniteOp>( 757 Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, 758 ValueRange args, OpBuilder* b) { 759 if (args[0].getType().isa<FloatType>()) { 760 auto posInf = APFloat::getInf( 761 args[0].getType().cast<FloatType>().getFloatSemantics()); 762 auto constPosInf = b->create<arith::ConstantOp>( 763 loc, b->getFloatAttr(args[0].getType(), posInf)); 764 Value absX = b->create<::mlir::math::AbsFOp>(loc, args[0]); 765 return b->create<::mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, 766 absX, constPosInf); 767 } 768 return nullptr; 769 } 770 771 /// Implements the conversion of HLO op to scalar op (to use within region of a 772 /// linalg.generic op) for compare-select style operations like min/max. 773 template <typename... Args> 774 struct CompareSelectOpToStdScalarOp { 775 static Value map(Location /*loc*/, 776 ComparisonDirection /*comparison_direction*/, 777 ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, 778 ValueRange /*args*/, OpBuilder* /*b*/) { 779 return nullptr; 780 } 781 }; 782 783 /// Specialization which allows converting to a comparison operation in standard 784 /// dialect with a given predicate based on the element type of the operand. 785 template <typename SupportedType, typename StdCompareOp, typename Predicate, 786 typename... Args> 787 struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate, 788 Args...> { 789 static Value map(Location loc, ComparisonDirection comparisonDirection, 790 ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, 791 ValueRange args, OpBuilder* b) { 792 Type elementType = getElementTypeOrSelf(argTypes.front()); 793 if (elementType.isa<SupportedType>()) { 794 auto predicate = getCmpPredicate<Predicate>( 795 comparisonDirection, !elementType.isUnsignedInteger()); 796 assert(predicate.has_value() && "expected valid comparison direction"); 797 auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(), 798 args[0], args[1]); 799 return b->create<::mlir::arith::SelectOp>(loc, cmp, args[0], args[1]); 800 } 801 return CompareSelectOpToStdScalarOp<Args...>::map( 802 loc, comparisonDirection, resultTypes, argTypes, args, b); 803 } 804 }; 805 806 inline Value mhloAlwaysPropagateNaN(Value v, ValueRange args, Location loc, 807 OpBuilder* b) { 808 Type elementType = getElementTypeOrSelf(args.front().getType()); 809 if (auto floatType = elementType.dyn_cast<FloatType>()) { 810 Value isnan = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, 811 args[0], args[1]); 812 813 auto nanApfloat = APFloat::getQNaN(floatType.getFloatSemantics()); 814 Value nan = getConstantOrSplat(b, loc, args[0].getType(), 815 b->getFloatAttr(floatType, nanApfloat)); 816 v = b->create<mlir::arith::SelectOp>(loc, isnan, nan, v); 817 } 818 return v; 819 } 820 821 template <> 822 inline Value mapMhloOpToStdScalarOp<mhlo::LogisticOp>( 823 Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> /*argTypes*/, 824 ValueRange args, OpBuilder* b) { 825 auto ty = resultTypes.front().cast<FloatType>(); 826 Value one = b->create<arith::ConstantOp>(loc, b->getFloatAttr(ty, 1.0)); 827 Value x = args.front(); 828 Value negX = b->create<arith::NegFOp>(loc, x); 829 Value expNegX = b->create<::mlir::math::ExpOp>(loc, negX); 830 Value oneAddExpNegX = b->create<arith::AddFOp>(loc, one, expNegX); 831 return b->create<arith::DivFOp>(loc, one, oneAddExpNegX); 832 } 833 834 template <> 835 inline Value mapMhloOpToStdScalarOp<mhlo::ClampOp>(Location loc, 836 ArrayRef<Type> resultTypes, 837 ArrayRef<Type> argTypes, 838 ValueRange args, 839 OpBuilder* b) { 840 mhlo::ClampOp::Adaptor op(args); 841 // clamp(lb, x, ub) = min(max(lb, x), ub) 842 Value maxLbX = mapMhloOpToStdScalarOp<mhlo::MaxOp>( 843 loc, resultTypes, argTypes, {op.min(), op.operand()}, b); 844 return mapMhloOpToStdScalarOp<mhlo::MinOp>(loc, resultTypes, argTypes, 845 {maxLbX, op.max()}, b); 846 } 847 848 template <typename U, typename S> 849 inline Value makeSafeIntDiv(ImplicitLocOpBuilder& lb, Type originalType, 850 Value lhs, Value rhs, Value returnedOnZero, 851 Value returnedOnSignedOverflow) { 852 Type type = lhs.getType(); 853 auto elementType = getElementTypeOrSelf(type).cast<IntegerType>(); 854 Value zero = lb.create<arith::ConstantOp>(lb.getZeroAttr(type)); 855 auto makeConstant = [&](const APInt& i) { 856 return getConstantOrSplat(&lb, lb.getLoc(), type, 857 lb.getIntegerAttr(elementType, i)); 858 }; 859 Value one = makeConstant(APInt(elementType.getWidth(), 1)); 860 Value rhsIsZero = 861 lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, rhs, zero); 862 863 // For unsigned just set the divisor to 1 when it would be 0. 864 if (originalType.isUnsignedInteger()) { 865 Value safeRhs = lb.create<arith::SelectOp>(rhsIsZero, one, rhs); 866 Value safeDiv = lb.create<U>(lhs, safeRhs); 867 return lb.create<arith::SelectOp>(rhsIsZero, returnedOnZero, safeDiv); 868 } 869 870 // For signed also check for INT_MIN / -1. 871 Value smin = makeConstant(APInt::getSignedMinValue(elementType.getWidth())); 872 Value lhsIsSmin = 873 lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, lhs, smin); 874 Value minusOne = makeConstant(APInt::getAllOnesValue(elementType.getWidth())); 875 Value rhsIsMinusOne = 876 lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, rhs, minusOne); 877 Value hasIntMinOverflow = lb.create<arith::AndIOp>(lhsIsSmin, rhsIsMinusOne); 878 Value rhsIsUnsafe = lb.create<arith::OrIOp>(rhsIsZero, hasIntMinOverflow); 879 Value safeRhs = lb.create<arith::SelectOp>(rhsIsUnsafe, one, rhs); 880 Value safeDiv = lb.create<S>(lhs, safeRhs); 881 Value safeSmin = lb.create<arith::SelectOp>( 882 hasIntMinOverflow, returnedOnSignedOverflow, safeDiv); 883 return lb.create<arith::SelectOp>(rhsIsZero, returnedOnZero, safeSmin); 884 } 885 886 template <> 887 inline Value mapMhloOpToStdScalarOp<mhlo::DivOp>(Location loc, 888 ArrayRef<Type> resultTypes, 889 ArrayRef<Type> argTypes, 890 ValueRange args, 891 OpBuilder* b) { 892 Type originalType = getElementTypeOrSelf(argTypes.front()); 893 if (originalType.isa<ComplexType, FloatType>()) { 894 return MapMhloOpToScalarOpImpl<IsFloatType, arith::DivFOp, IsComplexType, 895 complex::DivOp>{}(loc, resultTypes, argTypes, 896 args, b); 897 } 898 899 // Integer division overflow behavior: 900 // 901 // X / 0 == -1 902 // INT_SMIN /s -1 = INT_SMIN 903 ImplicitLocOpBuilder lb(loc, *b); 904 Type type = args.front().getType(); 905 auto elementType = getElementTypeOrSelf(type).cast<IntegerType>(); 906 auto makeConstant = [&](const APInt& i) { 907 return getConstantOrSplat(&lb, lb.getLoc(), type, 908 lb.getIntegerAttr(elementType, i)); 909 }; 910 Value minusOne = makeConstant(APInt::getAllOnesValue(elementType.getWidth())); 911 Value smin = makeConstant(APInt::getSignedMinValue(elementType.getWidth())); 912 return makeSafeIntDiv<arith::DivUIOp, arith::DivSIOp>( 913 lb, originalType, args[0], args[1], /*returnedOnZero=*/minusOne, 914 /*returnedOnSignedOverflow=*/smin); 915 } 916 917 template <> 918 inline Value mapMhloOpToStdScalarOp<mhlo::RemOp>(Location loc, 919 ArrayRef<Type> resultTypes, 920 ArrayRef<Type> argTypes, 921 ValueRange args, 922 OpBuilder* b) { 923 Type originalType = getElementTypeOrSelf(argTypes.front()); 924 if (originalType.isa<ComplexType, FloatType>()) { 925 return MapMhloOpToScalarOpImpl<IsFloatType, arith::RemFOp>{}( 926 loc, resultTypes, argTypes, args, b); 927 } 928 929 // Integer remainder overflow behavior: 930 // 931 // X % 0 == X 932 // INT_SMIN %s -1 = 0 933 ImplicitLocOpBuilder lb(loc, *b); 934 Type type = args.front().getType(); 935 Value zero = lb.create<arith::ConstantOp>(lb.getZeroAttr(type)); 936 return makeSafeIntDiv<arith::RemUIOp, arith::RemSIOp>( 937 lb, originalType, args[0], args[1], /*returnedOnZero=*/args[0], 938 /*returnedOnSignedOverflow=*/zero); 939 } 940 941 template <> 942 inline Value mapMhloOpToStdScalarOp<mhlo::NegOp>(Location loc, 943 ArrayRef<Type> resultTypes, 944 ArrayRef<Type> argTypes, 945 ValueRange args, 946 OpBuilder* b) { 947 Type elementType = getElementTypeOrSelf(args.front().getType()); 948 if (elementType.isa<ComplexType, FloatType>()) { 949 return MapMhloOpToScalarOpImpl<IsFloatType, ::mlir::arith::NegFOp, 950 IsComplexType, ::mlir::complex::NegOp>{}( 951 loc, resultTypes, argTypes, args, b); 952 } 953 if (elementType.isa<IntegerType>()) { 954 // lmhlo.neg(x, result) -> result = sub(0, x) 955 Value lhs = args[0]; 956 Value zeroIntval = 957 b->create<arith::ConstantOp>(loc, b->getZeroAttr(lhs.getType())); 958 return b->create<ScalarIOp<mhlo::SubtractOp>>(loc, zeroIntval, lhs); 959 } 960 return nullptr; 961 } 962 963 template <> 964 inline Value mapMhloOpToStdScalarOp<mhlo::NotOp>(Location loc, 965 ArrayRef<Type> /*ResultTypes*/, 966 ArrayRef<Type> /*argTypes*/, 967 ValueRange args, 968 OpBuilder* b) { 969 Type elementType = getElementTypeOrSelf(args.front().getType()); 970 if (auto integerType = elementType.dyn_cast<IntegerType>()) { 971 // lmhlo.not(x) -> x ^ -1 972 Value allOnes = getConstantOrSplat( 973 b, loc, args[0].getType(), 974 b->getIntegerAttr(integerType, 975 APInt::getAllOnesValue(integerType.getWidth()))); 976 return b->create<::mlir::arith::XOrIOp>(loc, allOnes, args[0]); 977 } 978 return nullptr; 979 } 980 981 template <> 982 inline Value mapMhloOpToStdScalarOp<mhlo::PowOp>(Location loc, 983 ArrayRef<Type> resultTypes, 984 ArrayRef<Type> argTypes, 985 ValueRange args, 986 OpBuilder* b) { 987 mhlo::PowOp::Adaptor adaptor(args); 988 auto lb = ImplicitLocOpBuilder(loc, *b); 989 // Floating point can use std::powf 990 auto resultType = resultTypes.front(); 991 if (resultType.isa<ComplexType, FloatType>()) { 992 return MapMhloOpToScalarOpImpl<IsFloatType, math::PowFOp, IsComplexType, 993 complex::PowOp>{}(loc, resultTypes, argTypes, 994 args, b); 995 } 996 997 // Exponentiation by squaring: 998 // https://en.wikipedia.org/wiki/Exponentiation_by_squaring; 999 Value negOne = 1000 lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, -1)); 1001 Value zero = lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, 0)); 1002 Value one = lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, 1)); 1003 Value two = lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, 2)); 1004 Value step = lb.create<arith::ConstantIndexOp>(1); 1005 Value lowerBound = lb.create<arith::ConstantIndexOp>(0); 1006 // Everything else would overflow for any exponent > 1, as 2^64 1007 // is the larget possible exponent for a 64-bit integer, and 1008 // that's 1 << 6. 1009 Value upperBound = lb.create<arith::ConstantIndexOp>(6); 1010 auto originalBase = adaptor.lhs(); 1011 auto originalExponent = adaptor.rhs(); 1012 1013 Value accum = 1014 lb.create<scf::ForOp>( 1015 lowerBound, upperBound, step, 1016 SmallVector<Value>({one, originalBase, originalExponent}), 1017 [&](OpBuilder& b, Location, Value /*v*/, ValueRange iters) { 1018 Value accum = iters[0]; 1019 Value base = iters[1]; 1020 Value exponent = iters[2]; 1021 1022 Value condition = b.create<arith::CmpIOp>( 1023 loc, arith::CmpIPredicate::eq, 1024 b.create<::mlir::arith::AndIOp>(loc, exponent, one), one); 1025 Value multiplied = 1026 b.create<::mlir::arith::MulIOp>(loc, accum, base); 1027 accum = b.create<::mlir::arith::SelectOp>(loc, condition, 1028 multiplied, accum); 1029 base = b.create<::mlir::arith::MulIOp>(loc, base, base); 1030 exponent = b.create<::mlir::arith::ShRUIOp>(loc, exponent, one); 1031 b.create<scf::YieldOp>( 1032 loc, SmallVector<Value>({accum, base, exponent})); 1033 }) 1034 .getResult(0); 1035 1036 Value rhsIsEven = lb.create<arith::CmpIOp>( 1037 arith::CmpIPredicate::eq, lb.create<arith::RemSIOp>(adaptor.rhs(), two), 1038 zero); 1039 Value rhsIsNegative = 1040 lb.create<arith::CmpIOp>(arith::CmpIPredicate::slt, adaptor.rhs(), zero); 1041 Value lhsIsOne = 1042 lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, adaptor.lhs(), one); 1043 Value lhsIsNegOne = 1044 lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, adaptor.lhs(), negOne); 1045 1046 // The accum is correct when the rhs is non-negative. When rhs is 1047 // negative, we return 0 for integer, with the exception of lhs values of 1 1048 // and -1 which have integer results for negative exponents. Specifically, the 1049 // calulation is the following: 1050 // 1051 // - Return accum if the rhs is not negative. 1052 // - Return 1 or -1 depending on the parity of rhs when the lhs is -1. 1053 // - Return 1 if lhs is 1. 1054 // - Else return 0. 1055 Value ifLhsIsOne = lb.create<::mlir::arith::SelectOp>(lhsIsOne, one, zero); 1056 Value ifLhsIsNegOne = lb.create<::mlir::arith::SelectOp>( 1057 lhsIsNegOne, lb.create<::mlir::arith::SelectOp>(rhsIsEven, one, negOne), 1058 ifLhsIsOne); 1059 return lb.create<::mlir::arith::SelectOp>(rhsIsNegative, ifLhsIsNegOne, 1060 accum); 1061 } 1062 1063 template <> 1064 inline Value mapMhloOpToStdScalarOp<mhlo::SelectOp>(Location loc, 1065 ArrayRef<Type> resultTypes, 1066 ArrayRef<Type> argTypes, 1067 ValueRange args, 1068 OpBuilder* b) { 1069 return MapMhloOpToScalarOpImpl<::mlir::arith::SelectOp>{}(loc, resultTypes, 1070 argTypes, args, b); 1071 } 1072 1073 template <> 1074 inline Value mapMhloOpToStdScalarOp<mhlo::SignOp>(Location loc, 1075 ArrayRef<Type> resultTypes, 1076 ArrayRef<Type> /*argTypes*/, 1077 ValueRange args, 1078 OpBuilder* b) { 1079 Type elementType = getElementTypeOrSelf(args.front().getType()); 1080 if (auto floatType = elementType.dyn_cast<FloatType>()) { 1081 Value zero = 1082 b->create<arith::ConstantOp>(loc, b->getZeroAttr(args[0].getType())); 1083 Value ne0I1 = b->create<::mlir::arith::CmpFOp>( 1084 loc, arith::CmpFPredicate::ONE, args[0], zero); 1085 Value ne0Float = 1086 b->create<::mlir::arith::UIToFPOp>(loc, zero.getType(), ne0I1); 1087 Value copySign = b->create<::mlir::math::CopySignOp>(loc, resultTypes, 1088 ne0Float, args[0]); 1089 auto isNan = b->create<::mlir::arith::CmpFOp>( 1090 loc, arith::CmpFPredicate::UNO, args[0], args[0]); 1091 return b->create<::mlir::arith::SelectOp>(loc, isNan, args[0], copySign); 1092 } 1093 if (auto integerType = elementType.dyn_cast<IntegerType>()) { 1094 // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) 1095 Value zero = 1096 b->create<arith::ConstantOp>(loc, b->getZeroAttr(args[0].getType())); 1097 Value bitwidthMinusOne = getConstantOrSplat( 1098 b, loc, args[0].getType(), 1099 b->getIntegerAttr(integerType, integerType.getWidth() - 1)); 1100 Value one = getConstantOrSplat(b, loc, args[0].getType(), 1101 b->getIntegerAttr(integerType, 1)); 1102 Value cmp = b->create<::mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 1103 args[0], zero); 1104 Value ashr = 1105 b->create<::mlir::arith::ShRSIOp>(loc, args[0], bitwidthMinusOne); 1106 Value orOp = b->create<::mlir::arith::OrIOp>(loc, ashr, one); 1107 return b->create<::mlir::arith::SelectOp>(loc, cmp, zero, orOp); 1108 } 1109 if (elementType.isa<ComplexType>()) { 1110 return b->create<::mlir::complex::SignOp>(loc, elementType, args.front()); 1111 } 1112 return nullptr; 1113 } 1114 1115 } // namespace impl 1116 1117 struct MhloOpToStdScalarOp { 1118 // Converts mhlo 'op' to linalg and arith ops. 1119 template <typename MhloOpTy> 1120 static Value mapOp(MhloOpTy op, ArrayRef<Type> resultTypes, ValueRange args, 1121 OpBuilder* b) { 1122 auto argTypes = llvm::to_vector(op->getOperandTypes()); 1123 return mapOpWithArgTypes(op, resultTypes, argTypes, args, b); 1124 } 1125 1126 // Converts mhlo 'op' to linalg and arith ops. The types of 'args' may already 1127 // be converted, 'argTypes' are their original types. 1128 template <typename MhloOpTy> 1129 static Value mapOpWithArgTypes(MhloOpTy op, ArrayRef<Type> resultTypes, 1130 ArrayRef<Type> argTypes, ValueRange args, 1131 OpBuilder* b) { 1132 static_assert(!std::is_same<MhloOpTy, mhlo::ConvertOp>::value); 1133 return mapOpOfType<MhloOpTy>(op.getLoc(), resultTypes, argTypes, args, b); 1134 } 1135 // Overload for mhlo::ReducePrecisionOp. 1136 static Value mapOpWithArgTypes(mhlo::ReducePrecisionOp op, 1137 ArrayRef<Type> /*ResultTypes*/, 1138 ArrayRef<Type> argTypes, ValueRange args, 1139 OpBuilder* b) { 1140 return impl::mapReducePrecisionOpToStdScalarOp( 1141 op.getLoc(), argTypes, args, b, op.exponent_bits(), op.mantissa_bits()); 1142 } 1143 // Overload for mhlo::CompareOp. 1144 static Value mapOpWithArgTypes(mhlo::CompareOp op, ArrayRef<Type> resultTypes, 1145 ArrayRef<Type> argTypes, ValueRange args, 1146 OpBuilder* b) { 1147 auto comparisonDirection = op.comparison_direction(); 1148 return impl::mapCompareOpToStdScalarOp(op.getLoc(), comparisonDirection, 1149 resultTypes, argTypes, args, b); 1150 } 1151 // Overload for mhlo::ConvertOp. 1152 static Value mapOpWithArgTypes(mhlo::ConvertOp op, ArrayRef<Type> resultTypes, 1153 ArrayRef<Type> argTypes, ValueRange args, 1154 OpBuilder* b) { 1155 return impl::mapConvertOpToStdScalarOp(op.getLoc(), op.getType(), 1156 resultTypes, argTypes, args, b); 1157 } 1158 1159 // Converts mhlo 'op' (except mhlo::CompareOp) to linalg and arith ops. 1160 template <typename MhloOpTy> 1161 static Value mapOpOfType(Location loc, ArrayRef<Type> resultTypes, 1162 ArrayRef<Type> argTypes, ValueRange args, 1163 OpBuilder* b) { 1164 static_assert(!std::is_same<MhloOpTy, mhlo::CompareOp>::value, "invalid"); 1165 if (std::is_same<MhloOpTy, mhlo::ConvertOp>::value) { 1166 // Note: this assumes that the caller is passing result/arg types with 1167 // appropriate signedness. 1168 return impl::mapConvertOpToStdScalarOp(loc, resultTypes, resultTypes, 1169 argTypes, args, b); 1170 } 1171 return impl::mapMhloOpToStdScalarOp<MhloOpTy>(loc, resultTypes, argTypes, 1172 args, b); 1173 } 1174 }; 1175 1176 } // namespace mhlo 1177 } // namespace mlir 1178 1179 #endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H 1180