xref: /aosp_15_r20/external/swiftshader/third_party/llvm-16.0/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 //===-- TargetLowering.cpp - Implement the TargetLowering class -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements the TargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/CodeGen/TargetLowering.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Analysis/VectorUtils.h"
16 #include "llvm/CodeGen/CallingConvLower.h"
17 #include "llvm/CodeGen/CodeGenCommonISel.h"
18 #include "llvm/CodeGen/MachineFrameInfo.h"
19 #include "llvm/CodeGen/MachineFunction.h"
20 #include "llvm/CodeGen/MachineJumpTableInfo.h"
21 #include "llvm/CodeGen/MachineRegisterInfo.h"
22 #include "llvm/CodeGen/SelectionDAG.h"
23 #include "llvm/CodeGen/TargetRegisterInfo.h"
24 #include "llvm/IR/DataLayout.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/GlobalVariable.h"
27 #include "llvm/IR/LLVMContext.h"
28 #include "llvm/MC/MCAsmInfo.h"
29 #include "llvm/MC/MCExpr.h"
30 #include "llvm/Support/DivisionByConstantInfo.h"
31 #include "llvm/Support/ErrorHandling.h"
32 #include "llvm/Support/KnownBits.h"
33 #include "llvm/Support/MathExtras.h"
34 #include "llvm/Target/TargetMachine.h"
35 #include <cctype>
36 using namespace llvm;
37 
38 /// NOTE: The TargetMachine owns TLOF.
// Forward construction to the base class; all lowering state lives there.
TargetLowering::TargetLowering(const TargetMachine &tm)
    : TargetLoweringBase(tm) {}
41 
// Targets override this hook to name their target-specific SDNodes; the
// generic implementation knows no target opcodes and returns null.
const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
  return nullptr;
}
45 
isPositionIndependent() const46 bool TargetLowering::isPositionIndependent() const {
47   return getTargetMachine().isPositionIndependent();
48 }
49 
50 /// Check whether a given call node is in tail position within its function. If
51 /// so, it sets Chain to the input chain of the tail call.
bool TargetLowering::isInTailCallPosition(SelectionDAG &DAG, SDNode *Node,
                                          SDValue &Chain) const {
  const Function &F = DAG.getMachineFunction().getFunction();

  // First, check if tail calls have been disabled in this function.
  if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
    return false;

  // Conservatively require the attributes of the call to match those of
  // the return. Ignore following attributes because they don't affect the
  // call sequence.
  AttrBuilder CallerAttrs(F.getContext(), F.getAttributes().getRetAttrs());
  for (const auto &Attr : {Attribute::Alignment, Attribute::Dereferenceable,
                           Attribute::DereferenceableOrNull, Attribute::NoAlias,
                           Attribute::NonNull, Attribute::NoUndef})
    CallerAttrs.removeAttribute(Attr);

  // Any remaining return attribute could change the call/return sequence, so
  // refuse the tail call.
  if (CallerAttrs.hasAttributes())
    return false;

  // It's not safe to eliminate the sign / zero extension of the return value.
  // NOTE(review): this check is redundant — hasAttributes() above already
  // rejected the case where ZExt/SExt remained — but is kept as written.
  if (CallerAttrs.contains(Attribute::ZExt) ||
      CallerAttrs.contains(Attribute::SExt))
    return false;

  // Check if the only use is a function return node.
  return isUsedByReturnOnly(Node, Chain);
}
80 
parametersInCSRMatch(const MachineRegisterInfo & MRI,const uint32_t * CallerPreservedMask,const SmallVectorImpl<CCValAssign> & ArgLocs,const SmallVectorImpl<SDValue> & OutVals) const81 bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI,
82     const uint32_t *CallerPreservedMask,
83     const SmallVectorImpl<CCValAssign> &ArgLocs,
84     const SmallVectorImpl<SDValue> &OutVals) const {
85   for (unsigned I = 0, E = ArgLocs.size(); I != E; ++I) {
86     const CCValAssign &ArgLoc = ArgLocs[I];
87     if (!ArgLoc.isRegLoc())
88       continue;
89     MCRegister Reg = ArgLoc.getLocReg();
90     // Only look at callee saved registers.
91     if (MachineOperand::clobbersPhysReg(CallerPreservedMask, Reg))
92       continue;
93     // Check that we pass the value used for the caller.
94     // (We look for a CopyFromReg reading a virtual register that is used
95     //  for the function live-in value of register Reg)
96     SDValue Value = OutVals[I];
97     if (Value->getOpcode() == ISD::AssertZext)
98       Value = Value.getOperand(0);
99     if (Value->getOpcode() != ISD::CopyFromReg)
100       return false;
101     Register ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg();
102     if (MRI.getLiveInPhysReg(ArgReg) != Reg)
103       return false;
104   }
105   return true;
106 }
107 
108 /// Set CallLoweringInfo attribute flags based on a call instruction
109 /// and called function attributes.
void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call,
                                                     unsigned ArgIdx) {
  // Mirror each relevant parameter attribute of the call site into the
  // corresponding ArgListEntry flag.
  IsSExt = Call->paramHasAttr(ArgIdx, Attribute::SExt);
  IsZExt = Call->paramHasAttr(ArgIdx, Attribute::ZExt);
  IsInReg = Call->paramHasAttr(ArgIdx, Attribute::InReg);
  IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet);
  IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest);
  IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal);
  IsPreallocated = Call->paramHasAttr(ArgIdx, Attribute::Preallocated);
  IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca);
  IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned);
  IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf);
  IsSwiftAsync = Call->paramHasAttr(ArgIdx, Attribute::SwiftAsync);
  IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError);
  Alignment = Call->getParamStackAlign(ArgIdx);
  IndirectType = nullptr;
  // At most one of the pointee-carrying ABI attributes may be present.
  assert(IsByVal + IsPreallocated + IsInAlloca + IsSRet <= 1 &&
         "multiple ABI attributes?");
  // Record the pointee type for whichever indirect-passing attribute is set;
  // byval may additionally override the alignment from the parameter.
  if (IsByVal) {
    IndirectType = Call->getParamByValType(ArgIdx);
    if (!Alignment)
      Alignment = Call->getParamAlign(ArgIdx);
  }
  if (IsPreallocated)
    IndirectType = Call->getParamPreallocatedType(ArgIdx);
  if (IsInAlloca)
    IndirectType = Call->getParamInAllocaType(ArgIdx);
  if (IsSRet)
    IndirectType = Call->getParamStructRetType(ArgIdx);
}
140 
141 /// Generate a libcall taking the given operands as arguments and returning a
142 /// result of type RetVT.
std::pair<SDValue, SDValue>
TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT,
                            ArrayRef<SDValue> Ops,
                            MakeLibCallOptions CallOptions,
                            const SDLoc &dl,
                            SDValue InChain) const {
  // Default to hanging the call off the entry node when no chain is supplied.
  if (!InChain)
    InChain = DAG.getEntryNode();

  TargetLowering::ArgListTy Args;
  Args.reserve(Ops.size());

  TargetLowering::ArgListEntry Entry;
  for (unsigned i = 0; i < Ops.size(); ++i) {
    SDValue NewOp = Ops[i];
    Entry.Node = NewOp;
    Entry.Ty = Entry.Node.getValueType().getTypeForEVT(*DAG.getContext());
    // Pick sign- vs. zero-extension per operand type; exactly one is set.
    Entry.IsSExt = shouldSignExtendTypeInLibCall(NewOp.getValueType(),
                                                 CallOptions.IsSExt);
    Entry.IsZExt = !Entry.IsSExt;

    // When softening FP operations, consult the pre-softening operand type to
    // decide whether any extension is appropriate at all.
    if (CallOptions.IsSoften &&
        !shouldExtendTypeInLibCall(CallOptions.OpsVTBeforeSoften[i])) {
      Entry.IsSExt = Entry.IsZExt = false;
    }
    Args.push_back(Entry);
  }

  if (LC == RTLIB::UNKNOWN_LIBCALL)
    report_fatal_error("Unsupported library call operation!");
  SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
                                         getPointerTy(DAG.getDataLayout()));

  Type *RetTy = RetVT.getTypeForEVT(*DAG.getContext());
  TargetLowering::CallLoweringInfo CLI(DAG);
  // Same sign/zero-extension decision for the returned value.
  bool signExtend = shouldSignExtendTypeInLibCall(RetVT, CallOptions.IsSExt);
  bool zeroExtend = !signExtend;

  if (CallOptions.IsSoften &&
      !shouldExtendTypeInLibCall(CallOptions.RetVTBeforeSoften)) {
    signExtend = zeroExtend = false;
  }

  CLI.setDebugLoc(dl)
      .setChain(InChain)
      .setLibCallee(getLibcallCallingConv(LC), RetTy, Callee, std::move(Args))
      .setNoReturn(CallOptions.DoesNotReturn)
      .setDiscardResult(!CallOptions.IsReturnValueUsed)
      .setIsPostTypeLegalization(CallOptions.IsPostTypeLegalization)
      .setSExtResult(signExtend)
      .setZExtResult(zeroExtend);
  // Returns {result value, output chain} from lowering the call.
  return LowerCallTo(CLI);
}
196 
/// Choose a sequence of value types (one per load/store) that covers the
/// memory operation Op, writing them to MemOps. Returns false if more than
/// Limit operations would be required or alignment constraints cannot be met.
bool TargetLowering::findOptimalMemOpLowering(
    std::vector<EVT> &MemOps, unsigned Limit, const MemOp &Op, unsigned DstAS,
    unsigned SrcAS, const AttributeList &FuncAttributes) const {
  // A memcpy whose source is less aligned than its fixed destination cannot
  // be split optimally here when an op limit is in force.
  if (Limit != ~unsigned(0) && Op.isMemcpyWithFixedDstAlign() &&
      Op.getSrcAlign() < Op.getDstAlign())
    return false;

  // Let the target pick a preferred type first.
  EVT VT = getOptimalMemOpType(Op, FuncAttributes);

  if (VT == MVT::Other) {
    // Use the largest integer type whose alignment constraints are satisfied.
    // We only need to check DstAlign here as SrcAlign is always greater or
    // equal to DstAlign (or zero).
    VT = MVT::i64;
    if (Op.isFixedDstAlign())
      while (Op.getDstAlign() < (VT.getSizeInBits() / 8) &&
             !allowsMisalignedMemoryAccesses(VT, DstAS, Op.getDstAlign()))
        // Step down through the integer MVTs (i64 -> i32 -> i16 -> i8).
        VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1);
    assert(VT.isInteger());

    // Find the largest legal integer type.
    MVT LVT = MVT::i64;
    while (!isTypeLegal(LVT))
      LVT = (MVT::SimpleValueType)(LVT.SimpleTy - 1);
    assert(LVT.isInteger());

    // If the type we've chosen is larger than the largest legal integer type
    // then use that instead.
    if (VT.bitsGT(LVT))
      VT = LVT;
  }

  unsigned NumMemOps = 0;
  uint64_t Size = Op.size();
  while (Size) {
    unsigned VTSize = VT.getSizeInBits() / 8;
    while (VTSize > Size) {
      // For now, only use non-vector load / store's for the left-over pieces.
      EVT NewVT = VT;
      unsigned NewVTSize;

      bool Found = false;
      if (VT.isVector() || VT.isFloatingPoint()) {
        NewVT = (VT.getSizeInBits() > 64) ? MVT::i64 : MVT::i32;
        if (isOperationLegalOrCustom(ISD::STORE, NewVT) &&
            isSafeMemOpType(NewVT.getSimpleVT()))
          Found = true;
        else if (NewVT == MVT::i64 &&
                 isOperationLegalOrCustom(ISD::STORE, MVT::f64) &&
                 isSafeMemOpType(MVT::f64)) {
          // i64 is usually not legal on 32-bit targets, but f64 may be.
          NewVT = MVT::f64;
          Found = true;
        }
      }

      if (!Found) {
        // Fall back to successively smaller integer types, stopping at i8.
        do {
          NewVT = (MVT::SimpleValueType)(NewVT.getSimpleVT().SimpleTy - 1);
          if (NewVT == MVT::i8)
            break;
        } while (!isSafeMemOpType(NewVT.getSimpleVT()));
      }
      NewVTSize = NewVT.getSizeInBits() / 8;

      // If the new VT cannot cover all of the remaining bits, then consider
      // issuing a (or a pair of) unaligned and overlapping load / store.
      // NOTE(review): 'Fast' is only read when allowsMisalignedMemoryAccesses
      // returns true, which is expected to have written it — confirm.
      unsigned Fast;
      if (NumMemOps && Op.allowOverlap() && NewVTSize < Size &&
          allowsMisalignedMemoryAccesses(
              VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
              MachineMemOperand::MONone, &Fast) &&
          Fast)
        VTSize = Size;
      else {
        VT = NewVT;
        VTSize = NewVTSize;
      }
    }

    if (++NumMemOps > Limit)
      return false;

    MemOps.push_back(VT);
    Size -= VTSize;
  }

  return true;
}
286 
287 /// Soften the operands of a comparison. This code is shared among BR_CC,
288 /// SELECT_CC, and SETCC handlers.
softenSetCCOperands(SelectionDAG & DAG,EVT VT,SDValue & NewLHS,SDValue & NewRHS,ISD::CondCode & CCCode,const SDLoc & dl,const SDValue OldLHS,const SDValue OldRHS) const289 void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
290                                          SDValue &NewLHS, SDValue &NewRHS,
291                                          ISD::CondCode &CCCode,
292                                          const SDLoc &dl, const SDValue OldLHS,
293                                          const SDValue OldRHS) const {
294   SDValue Chain;
295   return softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, dl, OldLHS,
296                              OldRHS, Chain);
297 }
298 
void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
                                         SDValue &NewLHS, SDValue &NewRHS,
                                         ISD::CondCode &CCCode,
                                         const SDLoc &dl, const SDValue OldLHS,
                                         const SDValue OldRHS,
                                         SDValue &Chain,
                                         bool IsSignaling) const {
  // FIXME: Currently we cannot really respect all IEEE predicates due to libgcc
  // not supporting it. We can update this code when libgcc provides such
  // functions.

  assert((VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f128 || VT == MVT::ppcf128)
         && "Unsupported setcc type!");

  // Expand into one or more soft-fp libcall(s). LC2, when used, provides a
  // second predicate whose result is combined with LC1's below.
  RTLIB::Libcall LC1 = RTLIB::UNKNOWN_LIBCALL, LC2 = RTLIB::UNKNOWN_LIBCALL;
  // When set, the chosen libcall computes the inverse predicate and the final
  // condition code is inverted after the call.
  bool ShouldInvertCC = false;
  switch (CCCode) {
  case ISD::SETEQ:
  case ISD::SETOEQ:
    LC1 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
          (VT == MVT::f64) ? RTLIB::OEQ_F64 :
          (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
    break;
  case ISD::SETNE:
  case ISD::SETUNE:
    LC1 = (VT == MVT::f32) ? RTLIB::UNE_F32 :
          (VT == MVT::f64) ? RTLIB::UNE_F64 :
          (VT == MVT::f128) ? RTLIB::UNE_F128 : RTLIB::UNE_PPCF128;
    break;
  case ISD::SETGE:
  case ISD::SETOGE:
    LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
          (VT == MVT::f64) ? RTLIB::OGE_F64 :
          (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
    break;
  case ISD::SETLT:
  case ISD::SETOLT:
    LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
          (VT == MVT::f64) ? RTLIB::OLT_F64 :
          (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
    break;
  case ISD::SETLE:
  case ISD::SETOLE:
    LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
          (VT == MVT::f64) ? RTLIB::OLE_F64 :
          (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
    break;
  case ISD::SETGT:
  case ISD::SETOGT:
    LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
          (VT == MVT::f64) ? RTLIB::OGT_F64 :
          (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
    break;
  case ISD::SETO:
    // SETO is the inverse of SETUO.
    ShouldInvertCC = true;
    [[fallthrough]];
  case ISD::SETUO:
    LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
          (VT == MVT::f64) ? RTLIB::UO_F64 :
          (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
    break;
  case ISD::SETONE:
    // SETONE = O && UNE
    ShouldInvertCC = true;
    [[fallthrough]];
  case ISD::SETUEQ:
    // SETUEQ = UO || OEQ: requires two libcalls combined below.
    LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
          (VT == MVT::f64) ? RTLIB::UO_F64 :
          (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
    LC2 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
          (VT == MVT::f64) ? RTLIB::OEQ_F64 :
          (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
    break;
  default:
    // Invert CC for unordered comparisons
    ShouldInvertCC = true;
    switch (CCCode) {
    case ISD::SETULT:
      LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
            (VT == MVT::f64) ? RTLIB::OGE_F64 :
            (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
      break;
    case ISD::SETULE:
      LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
            (VT == MVT::f64) ? RTLIB::OGT_F64 :
            (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
      break;
    case ISD::SETUGT:
      LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
            (VT == MVT::f64) ? RTLIB::OLE_F64 :
            (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
      break;
    case ISD::SETUGE:
      LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
            (VT == MVT::f64) ? RTLIB::OLT_F64 :
            (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
      break;
    default: llvm_unreachable("Do not know how to soften this setcc!");
    }
  }

  // Use the target specific return value for comparison lib calls.
  EVT RetVT = getCmpLibcallReturnType();
  SDValue Ops[2] = {NewLHS, NewRHS};
  TargetLowering::MakeLibCallOptions CallOptions;
  EVT OpsVT[2] = { OldLHS.getValueType(),
                   OldRHS.getValueType() };
  CallOptions.setTypeListBeforeSoften(OpsVT, RetVT, true);
  auto Call = makeLibCall(DAG, LC1, RetVT, Ops, CallOptions, dl, Chain);
  // The comparison becomes: (libcall result) <CC> 0.
  NewLHS = Call.first;
  NewRHS = DAG.getConstant(0, dl, RetVT);

  CCCode = getCmpLibcallCC(LC1);
  if (ShouldInvertCC) {
    assert(RetVT.isInteger());
    CCCode = getSetCCInverse(CCCode, RetVT);
  }

  if (LC2 == RTLIB::UNKNOWN_LIBCALL) {
    // Update Chain.
    Chain = Call.second;
  } else {
    // Two-call predicate: materialize the first comparison, make the second
    // call, then combine with AND (inverted) or OR.
    EVT SetCCVT =
        getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), RetVT);
    SDValue Tmp = DAG.getSetCC(dl, SetCCVT, NewLHS, NewRHS, CCCode);
    auto Call2 = makeLibCall(DAG, LC2, RetVT, Ops, CallOptions, dl, Chain);
    CCCode = getCmpLibcallCC(LC2);
    if (ShouldInvertCC)
      CCCode = getSetCCInverse(CCCode, RetVT);
    NewLHS = DAG.getSetCC(dl, SetCCVT, Call2.first, NewRHS, CCCode);
    if (Chain)
      // Join the chains of both calls so later users depend on both.
      Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Call.second,
                          Call2.second);
    NewLHS = DAG.getNode(ShouldInvertCC ? ISD::AND : ISD::OR, dl,
                         Tmp.getValueType(), Tmp, NewLHS);
    // Signal to the caller that the result is a bare boolean, not a compare.
    NewRHS = SDValue();
  }
}
438 
439 /// Return the entry encoding for a jump table in the current function. The
440 /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
getJumpTableEncoding() const441 unsigned TargetLowering::getJumpTableEncoding() const {
442   // In non-pic modes, just use the address of a block.
443   if (!isPositionIndependent())
444     return MachineJumpTableInfo::EK_BlockAddress;
445 
446   // In PIC mode, if the target supports a GPRel32 directive, use it.
447   if (getTargetMachine().getMCAsmInfo()->getGPRel32Directive() != nullptr)
448     return MachineJumpTableInfo::EK_GPRel32BlockAddress;
449 
450   // Otherwise, use a label difference.
451   return MachineJumpTableInfo::EK_LabelDifference32;
452 }
453 
getPICJumpTableRelocBase(SDValue Table,SelectionDAG & DAG) const454 SDValue TargetLowering::getPICJumpTableRelocBase(SDValue Table,
455                                                  SelectionDAG &DAG) const {
456   // If our PIC model is GP relative, use the global offset table as the base.
457   unsigned JTEncoding = getJumpTableEncoding();
458 
459   if ((JTEncoding == MachineJumpTableInfo::EK_GPRel64BlockAddress) ||
460       (JTEncoding == MachineJumpTableInfo::EK_GPRel32BlockAddress))
461     return DAG.getGLOBAL_OFFSET_TABLE(getPointerTy(DAG.getDataLayout()));
462 
463   return Table;
464 }
465 
466 /// This returns the relocation base for the given PIC jumptable, the same as
467 /// getPICJumpTableRelocBase, but as an MCExpr.
468 const MCExpr *
getPICJumpTableRelocBaseExpr(const MachineFunction * MF,unsigned JTI,MCContext & Ctx) const469 TargetLowering::getPICJumpTableRelocBaseExpr(const MachineFunction *MF,
470                                              unsigned JTI,MCContext &Ctx) const{
471   // The normal PIC reloc base is the label at the start of the jump table.
472   return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx);
473 }
474 
475 bool
isOffsetFoldingLegal(const GlobalAddressSDNode * GA) const476 TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const {
477   const TargetMachine &TM = getTargetMachine();
478   const GlobalValue *GV = GA->getGlobal();
479 
480   // If the address is not even local to this DSO we will have to load it from
481   // a got and then add the offset.
482   if (!TM.shouldAssumeDSOLocal(*GV->getParent(), GV))
483     return false;
484 
485   // If the code is position independent we will have to add a base register.
486   if (isPositionIndependent())
487     return false;
488 
489   // Otherwise we can do it.
490   return true;
491 }
492 
493 //===----------------------------------------------------------------------===//
494 //  Optimization Methods
495 //===----------------------------------------------------------------------===//
496 
497 /// If the specified instruction has a constant integer operand and there are
498 /// bits set in that constant that are not demanded, then clear those bits and
499 /// return true.
bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
                                            const APInt &DemandedBits,
                                            const APInt &DemandedElts,
                                            TargetLoweringOpt &TLO) const {
  SDLoc DL(Op);
  unsigned Opcode = Op.getOpcode();

  // Do target-specific constant optimization.
  if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
    return TLO.New.getNode();

  // FIXME: ISD::SELECT, ISD::SELECT_CC
  switch (Opcode) {
  default:
    break;
  case ISD::XOR:
  case ISD::AND:
  case ISD::OR: {
    // Only a plain (non-opaque) constant RHS can be shrunk.
    auto *Op1C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
    if (!Op1C || Op1C->isOpaque())
      return false;

    // If this is a 'not' op, don't touch it because that's a canonical form.
    const APInt &C = Op1C->getAPIntValue();
    if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
      return false;

    // If the constant sets bits nobody demands, rebuild the node with those
    // bits cleared; the result is equivalent on the demanded bits.
    if (!C.isSubsetOf(DemandedBits)) {
      EVT VT = Op.getValueType();
      SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
      SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC);
      return TLO.CombineTo(Op, NewOp);
    }

    break;
  }
  }

  return false;
}
540 
ShrinkDemandedConstant(SDValue Op,const APInt & DemandedBits,TargetLoweringOpt & TLO) const541 bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
542                                             const APInt &DemandedBits,
543                                             TargetLoweringOpt &TLO) const {
544   EVT VT = Op.getValueType();
545   APInt DemandedElts = VT.isVector()
546                            ? APInt::getAllOnes(VT.getVectorNumElements())
547                            : APInt(1, 1);
548   return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
549 }
550 
551 /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
552 /// This uses isZExtFree and ZERO_EXTEND for the widening cast, but it could be
553 /// generalized for targets with other types of implicit widening casts.
bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
                                      const APInt &Demanded,
                                      TargetLoweringOpt &TLO) const {
  assert(Op.getNumOperands() == 2 &&
         "ShrinkDemandedOp only supports binary operators!");
  assert(Op.getNode()->getNumValues() == 1 &&
         "ShrinkDemandedOp only supports nodes with one result!");

  SelectionDAG &DAG = TLO.DAG;
  SDLoc dl(Op);

  // Early return, as this function cannot handle vector types.
  if (Op.getValueType().isVector())
    return false;

  // Don't do this if the node has another user, which may require the
  // full value.
  if (!Op.getNode()->hasOneUse())
    return false;

  // Search for the smallest integer type with free casts to and from
  // Op's type. For expedience, just check power-of-2 integer types.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  unsigned DemandedSize = Demanded.getActiveBits();
  // Round the candidate width up to a power of two before scanning upward.
  unsigned SmallVTBits = DemandedSize;
  if (!isPowerOf2_32(SmallVTBits))
    SmallVTBits = NextPowerOf2(SmallVTBits);
  for (; SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
    EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
    if (TLI.isTruncateFree(Op.getValueType(), SmallVT) &&
        TLI.isZExtFree(SmallVT, Op.getValueType())) {
      // We found a type with free casts.
      // Rebuild the op at the narrow width and widen the result back.
      SDValue X = DAG.getNode(
          Op.getOpcode(), dl, SmallVT,
          DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
          DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
      assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
      SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, Op.getValueType(), X);
      return TLO.CombineTo(Op, Z);
    }
  }
  return false;
}
597 
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,DAGCombinerInfo & DCI) const598 bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
599                                           DAGCombinerInfo &DCI) const {
600   SelectionDAG &DAG = DCI.DAG;
601   TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
602                         !DCI.isBeforeLegalizeOps());
603   KnownBits Known;
604 
605   bool Simplified = SimplifyDemandedBits(Op, DemandedBits, Known, TLO);
606   if (Simplified) {
607     DCI.AddToWorklist(Op.getNode());
608     DCI.CommitTargetLoweringOpt(TLO);
609   }
610   return Simplified;
611 }
612 
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,DAGCombinerInfo & DCI) const613 bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
614                                           const APInt &DemandedElts,
615                                           DAGCombinerInfo &DCI) const {
616   SelectionDAG &DAG = DCI.DAG;
617   TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
618                         !DCI.isBeforeLegalizeOps());
619   KnownBits Known;
620 
621   bool Simplified =
622       SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO);
623   if (Simplified) {
624     DCI.AddToWorklist(Op.getNode());
625     DCI.CommitTargetLoweringOpt(TLO);
626   }
627   return Simplified;
628 }
629 
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,KnownBits & Known,TargetLoweringOpt & TLO,unsigned Depth,bool AssumeSingleUse) const630 bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
631                                           KnownBits &Known,
632                                           TargetLoweringOpt &TLO,
633                                           unsigned Depth,
634                                           bool AssumeSingleUse) const {
635   EVT VT = Op.getValueType();
636 
637   // Since the number of lanes in a scalable vector is unknown at compile time,
638   // we track one bit which is implicitly broadcast to all lanes.  This means
639   // that all lanes in a scalable vector are considered demanded.
640   APInt DemandedElts = VT.isFixedLengthVector()
641                            ? APInt::getAllOnes(VT.getVectorNumElements())
642                            : APInt(1, 1);
643   return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
644                               AssumeSingleUse);
645 }
646 
647 // TODO: Under what circumstances can we create nodes? Constant folding?
SimplifyMultipleUseDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,SelectionDAG & DAG,unsigned Depth) const648 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
649     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
650     SelectionDAG &DAG, unsigned Depth) const {
651   EVT VT = Op.getValueType();
652 
653   // Limit search depth.
654   if (Depth >= SelectionDAG::MaxRecursionDepth)
655     return SDValue();
656 
657   // Ignore UNDEFs.
658   if (Op.isUndef())
659     return SDValue();
660 
661   // Not demanding any bits/elts from Op.
662   if (DemandedBits == 0 || DemandedElts == 0)
663     return DAG.getUNDEF(VT);
664 
665   bool IsLE = DAG.getDataLayout().isLittleEndian();
666   unsigned NumElts = DemandedElts.getBitWidth();
667   unsigned BitWidth = DemandedBits.getBitWidth();
668   KnownBits LHSKnown, RHSKnown;
669   switch (Op.getOpcode()) {
670   case ISD::BITCAST: {
671     if (VT.isScalableVector())
672       return SDValue();
673 
674     SDValue Src = peekThroughBitcasts(Op.getOperand(0));
675     EVT SrcVT = Src.getValueType();
676     EVT DstVT = Op.getValueType();
677     if (SrcVT == DstVT)
678       return Src;
679 
680     unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
681     unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
682     if (NumSrcEltBits == NumDstEltBits)
683       if (SDValue V = SimplifyMultipleUseDemandedBits(
684               Src, DemandedBits, DemandedElts, DAG, Depth + 1))
685         return DAG.getBitcast(DstVT, V);
686 
687     if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
688       unsigned Scale = NumDstEltBits / NumSrcEltBits;
689       unsigned NumSrcElts = SrcVT.getVectorNumElements();
690       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
691       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
692       for (unsigned i = 0; i != Scale; ++i) {
693         unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
694         unsigned BitOffset = EltOffset * NumSrcEltBits;
695         APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
696         if (!Sub.isZero()) {
697           DemandedSrcBits |= Sub;
698           for (unsigned j = 0; j != NumElts; ++j)
699             if (DemandedElts[j])
700               DemandedSrcElts.setBit((j * Scale) + i);
701         }
702       }
703 
704       if (SDValue V = SimplifyMultipleUseDemandedBits(
705               Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
706         return DAG.getBitcast(DstVT, V);
707     }
708 
709     // TODO - bigendian once we have test coverage.
710     if (IsLE && (NumSrcEltBits % NumDstEltBits) == 0) {
711       unsigned Scale = NumSrcEltBits / NumDstEltBits;
712       unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
713       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
714       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
715       for (unsigned i = 0; i != NumElts; ++i)
716         if (DemandedElts[i]) {
717           unsigned Offset = (i % Scale) * NumDstEltBits;
718           DemandedSrcBits.insertBits(DemandedBits, Offset);
719           DemandedSrcElts.setBit(i / Scale);
720         }
721 
722       if (SDValue V = SimplifyMultipleUseDemandedBits(
723               Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
724         return DAG.getBitcast(DstVT, V);
725     }
726 
727     break;
728   }
729   case ISD::AND: {
730     LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
731     RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
732 
733     // If all of the demanded bits are known 1 on one side, return the other.
734     // These bits cannot contribute to the result of the 'and' in this
735     // context.
736     if (DemandedBits.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
737       return Op.getOperand(0);
738     if (DemandedBits.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
739       return Op.getOperand(1);
740     break;
741   }
742   case ISD::OR: {
743     LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
744     RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
745 
746     // If all of the demanded bits are known zero on one side, return the
747     // other.  These bits cannot contribute to the result of the 'or' in this
748     // context.
749     if (DemandedBits.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
750       return Op.getOperand(0);
751     if (DemandedBits.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
752       return Op.getOperand(1);
753     break;
754   }
755   case ISD::XOR: {
756     LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
757     RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
758 
759     // If all of the demanded bits are known zero on one side, return the
760     // other.
761     if (DemandedBits.isSubsetOf(RHSKnown.Zero))
762       return Op.getOperand(0);
763     if (DemandedBits.isSubsetOf(LHSKnown.Zero))
764       return Op.getOperand(1);
765     break;
766   }
767   case ISD::SHL: {
768     // If we are only demanding sign bits then we can use the shift source
769     // directly.
770     if (const APInt *MaxSA =
771             DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
772       SDValue Op0 = Op.getOperand(0);
773       unsigned ShAmt = MaxSA->getZExtValue();
774       unsigned NumSignBits =
775           DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
776       unsigned UpperDemandedBits = BitWidth - DemandedBits.countTrailingZeros();
777       if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
778         return Op0;
779     }
780     break;
781   }
782   case ISD::SETCC: {
783     SDValue Op0 = Op.getOperand(0);
784     SDValue Op1 = Op.getOperand(1);
785     ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
786     // If (1) we only need the sign-bit, (2) the setcc operands are the same
787     // width as the setcc result, and (3) the result of a setcc conforms to 0 or
788     // -1, we may be able to bypass the setcc.
789     if (DemandedBits.isSignMask() &&
790         Op0.getScalarValueSizeInBits() == BitWidth &&
791         getBooleanContents(Op0.getValueType()) ==
792             BooleanContent::ZeroOrNegativeOneBooleanContent) {
793       // If we're testing X < 0, then this compare isn't needed - just use X!
794       // FIXME: We're limiting to integer types here, but this should also work
795       // if we don't care about FP signed-zero. The use of SETLT with FP means
796       // that we don't care about NaNs.
797       if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
798           (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
799         return Op0;
800     }
801     break;
802   }
803   case ISD::SIGN_EXTEND_INREG: {
804     // If none of the extended bits are demanded, eliminate the sextinreg.
805     SDValue Op0 = Op.getOperand(0);
806     EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
807     unsigned ExBits = ExVT.getScalarSizeInBits();
808     if (DemandedBits.getActiveBits() <= ExBits)
809       return Op0;
810     // If the input is already sign extended, just drop the extension.
811     unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
812     if (NumSignBits >= (BitWidth - ExBits + 1))
813       return Op0;
814     break;
815   }
816   case ISD::ANY_EXTEND_VECTOR_INREG:
817   case ISD::SIGN_EXTEND_VECTOR_INREG:
818   case ISD::ZERO_EXTEND_VECTOR_INREG: {
819     if (VT.isScalableVector())
820       return SDValue();
821 
822     // If we only want the lowest element and none of extended bits, then we can
823     // return the bitcasted source vector.
824     SDValue Src = Op.getOperand(0);
825     EVT SrcVT = Src.getValueType();
826     EVT DstVT = Op.getValueType();
827     if (IsLE && DemandedElts == 1 &&
828         DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
829         DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
830       return DAG.getBitcast(DstVT, Src);
831     }
832     break;
833   }
834   case ISD::INSERT_VECTOR_ELT: {
835     if (VT.isScalableVector())
836       return SDValue();
837 
838     // If we don't demand the inserted element, return the base vector.
839     SDValue Vec = Op.getOperand(0);
840     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
841     EVT VecVT = Vec.getValueType();
842     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
843         !DemandedElts[CIdx->getZExtValue()])
844       return Vec;
845     break;
846   }
847   case ISD::INSERT_SUBVECTOR: {
848     if (VT.isScalableVector())
849       return SDValue();
850 
851     SDValue Vec = Op.getOperand(0);
852     SDValue Sub = Op.getOperand(1);
853     uint64_t Idx = Op.getConstantOperandVal(2);
854     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
855     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
856     // If we don't demand the inserted subvector, return the base vector.
857     if (DemandedSubElts == 0)
858       return Vec;
859     // If this simply widens the lowest subvector, see if we can do it earlier.
860     // TODO: REMOVE ME - SimplifyMultipleUseDemandedBits shouldn't be creating
861     // general nodes like this.
862     if (Idx == 0 && Vec.isUndef()) {
863       if (SDValue NewSub = SimplifyMultipleUseDemandedBits(
864               Sub, DemandedBits, DemandedSubElts, DAG, Depth + 1))
865         return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
866                            Op.getOperand(0), NewSub, Op.getOperand(2));
867     }
868     break;
869   }
870   case ISD::VECTOR_SHUFFLE: {
871     assert(!VT.isScalableVector());
872     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
873 
874     // If all the demanded elts are from one operand and are inline,
875     // then we can use the operand directly.
876     bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
877     for (unsigned i = 0; i != NumElts; ++i) {
878       int M = ShuffleMask[i];
879       if (M < 0 || !DemandedElts[i])
880         continue;
881       AllUndef = false;
882       IdentityLHS &= (M == (int)i);
883       IdentityRHS &= ((M - NumElts) == i);
884     }
885 
886     if (AllUndef)
887       return DAG.getUNDEF(Op.getValueType());
888     if (IdentityLHS)
889       return Op.getOperand(0);
890     if (IdentityRHS)
891       return Op.getOperand(1);
892     break;
893   }
894   default:
895     // TODO: Probably okay to remove after audit; here to reduce change size
896     // in initial enablement patch for scalable vectors
897     if (VT.isScalableVector())
898       return SDValue();
899 
900     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
901       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
902               Op, DemandedBits, DemandedElts, DAG, Depth))
903         return V;
904     break;
905   }
906   return SDValue();
907 }
908 
SimplifyMultipleUseDemandedBits(SDValue Op,const APInt & DemandedBits,SelectionDAG & DAG,unsigned Depth) const909 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
910     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
911     unsigned Depth) const {
912   EVT VT = Op.getValueType();
913   // Since the number of lanes in a scalable vector is unknown at compile time,
914   // we track one bit which is implicitly broadcast to all lanes.  This means
915   // that all lanes in a scalable vector are considered demanded.
916   APInt DemandedElts = VT.isFixedLengthVector()
917                            ? APInt::getAllOnes(VT.getVectorNumElements())
918                            : APInt(1, 1);
919   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
920                                          Depth);
921 }
922 
SimplifyMultipleUseDemandedVectorElts(SDValue Op,const APInt & DemandedElts,SelectionDAG & DAG,unsigned Depth) const923 SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
924     SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
925     unsigned Depth) const {
926   APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
927   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
928                                          Depth);
929 }
930 
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
//      or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
//
// Op must be an SRL or SRA node; DemandedBits/DemandedElts describe which
// result bits/lanes the caller actually needs (used both to match splat
// constants and to relax the sign-bit requirements below). Returns the
// replacement node, or an empty SDValue if no AVG node can be formed or the
// required AVG opcode is not legal/custom for the narrowed type.
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
                                 const TargetLowering &TLI,
                                 const APInt &DemandedBits,
                                 const APInt &DemandedElts,
                                 unsigned Depth) {
  assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
         "SRL or SRA node is required here!");
  // Is the right shift using an immediate value of 1?
  ConstantSDNode *N1C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
  if (!N1C || !N1C->isOne())
    return SDValue();

  // We are looking for an avgfloor
  // add(ext, ext)
  // or one of these as a avgceil
  // add(add(ext, ext), 1)
  // add(add(ext, 1), ext)
  // add(ext, add(ext, 1))
  SDValue Add = Op.getOperand(0);
  if (Add.getOpcode() != ISD::ADD)
    return SDValue();

  SDValue ExtOpA = Add.getOperand(0);
  SDValue ExtOpB = Add.getOperand(1);
  // Helper for the avgceil forms: given the three values involved in the
  // nested add, if exactly one of them is the (splat) constant 1, record the
  // other two as the averaged operands and report success. Note this writes
  // ExtOpA/ExtOpB by capture.
  auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3) {
    ConstantSDNode *ConstOp;
    if ((ConstOp = isConstOrConstSplat(Op1, DemandedElts)) &&
        ConstOp->isOne()) {
      ExtOpA = Op2;
      ExtOpB = Op3;
      return true;
    }
    if ((ConstOp = isConstOrConstSplat(Op2, DemandedElts)) &&
        ConstOp->isOne()) {
      ExtOpA = Op1;
      ExtOpB = Op3;
      return true;
    }
    if ((ConstOp = isConstOrConstSplat(Op3, DemandedElts)) &&
        ConstOp->isOne()) {
      ExtOpA = Op1;
      ExtOpB = Op2;
      return true;
    }
    return false;
  };
  // IsCeil is true when a "+1" was found inside either add operand (the three
  // avgceil patterns above); otherwise this is the avgfloor pattern.
  bool IsCeil =
      (ExtOpA.getOpcode() == ISD::ADD &&
       MatchOperands(ExtOpA.getOperand(0), ExtOpA.getOperand(1), ExtOpB)) ||
      (ExtOpB.getOpcode() == ISD::ADD &&
       MatchOperands(ExtOpB.getOperand(0), ExtOpB.getOperand(1), ExtOpA));

  // If the shift is signed (sra):
  //  - Needs >= 2 sign bit for both operands.
  //  - Needs >= 2 zero bits.
  // If the shift is unsigned (srl):
  //  - Needs >= 1 zero bit for both operands.
  //  - Needs 1 demanded bit zero and >= 2 sign bits.
  unsigned ShiftOpc = Op.getOpcode();
  bool IsSigned = false;
  // Number of high bits known redundant in both operands (either sign copies
  // or zeros); this bounds how narrow the AVG type may be made below.
  // NB: this local shadows the llvm::KnownBits class name.
  unsigned KnownBits;
  unsigned NumSignedA = DAG.ComputeNumSignBits(ExtOpA, DemandedElts, Depth);
  unsigned NumSignedB = DAG.ComputeNumSignBits(ExtOpB, DemandedElts, Depth);
  // ComputeNumSignBits counts the sign bit itself, so subtract one to get the
  // number of redundant high bits above a value that both operands share.
  unsigned NumSigned = std::min(NumSignedA, NumSignedB) - 1;
  unsigned NumZeroA =
      DAG.computeKnownBits(ExtOpA, DemandedElts, Depth).countMinLeadingZeros();
  unsigned NumZeroB =
      DAG.computeKnownBits(ExtOpB, DemandedElts, Depth).countMinLeadingZeros();
  unsigned NumZero = std::min(NumZeroA, NumZeroB);

  // Classify the average as signed or unsigned, preferring whichever view
  // (zero bits vs sign bits) leaves more redundant high bits to truncate.
  switch (ShiftOpc) {
  default:
    llvm_unreachable("Unexpected ShiftOpc in combineShiftToAVG");
  case ISD::SRA: {
    // An sra of zero-extended operands behaves like an srl (the add of two
    // values with >= 2 leading zeros cannot set the sign bit), so an unsigned
    // average is valid.
    if (NumZero >= 2 && NumSigned < NumZero) {
      IsSigned = false;
      KnownBits = NumZero;
      break;
    }
    if (NumSigned >= 1) {
      IsSigned = true;
      KnownBits = NumSigned;
      break;
    }
    return SDValue();
  }
  case ISD::SRL: {
    if (NumZero >= 1 && NumSigned < NumZero) {
      IsSigned = false;
      KnownBits = NumZero;
      break;
    }
    // An srl of sign-extended operands matches an sra when the caller never
    // looks at the (shifted-in) sign bit, so a signed average is still valid.
    if (NumSigned >= 1 && DemandedBits.isSignBitClear()) {
      IsSigned = true;
      KnownBits = NumSigned;
      break;
    }
    return SDValue();
  }
  }

  unsigned AVGOpc = IsCeil ? (IsSigned ? ISD::AVGCEILS : ISD::AVGCEILU)
                           : (IsSigned ? ISD::AVGFLOORS : ISD::AVGFLOORU);

  // Find the smallest power-2 type that is legal for this vector size and
  // operation, given the original type size and the number of known sign/zero
  // bits. Clamp to at least 8 bits, the narrowest integer type considered.
  EVT VT = Op.getValueType();
  unsigned MinWidth =
      std::max<unsigned>(VT.getScalarSizeInBits() - KnownBits, 8);
  EVT NVT = EVT::getIntegerVT(*DAG.getContext(), PowerOf2Ceil(MinWidth));
  if (VT.isVector())
    NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
  if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT))
    return SDValue();

  // Truncate both operands to the narrow type (safe: the truncated bits are
  // redundant sign/zero bits), average there, then extend the result back.
  SDLoc DL(Op);
  SDValue ResultAVG =
      DAG.getNode(AVGOpc, DL, NVT, DAG.getNode(ISD::TRUNCATE, DL, NVT, ExtOpA),
                  DAG.getNode(ISD::TRUNCATE, DL, NVT, ExtOpB));
  return DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT,
                     ResultAVG);
}
1056 
1057 /// Look at Op. At this point, we know that only the OriginalDemandedBits of the
1058 /// result of Op are ever used downstream. If we can use this information to
1059 /// simplify Op, create a new simplified DAG node and return true, returning the
1060 /// original and new nodes in Old and New. Otherwise, analyze the expression and
1061 /// return a mask of Known bits for the expression (used to simplify the
1062 /// caller).  The Known bits may only be accurate for those bits in the
1063 /// OriginalDemandedBits and OriginalDemandedElts.
SimplifyDemandedBits(SDValue Op,const APInt & OriginalDemandedBits,const APInt & OriginalDemandedElts,KnownBits & Known,TargetLoweringOpt & TLO,unsigned Depth,bool AssumeSingleUse) const1064 bool TargetLowering::SimplifyDemandedBits(
1065     SDValue Op, const APInt &OriginalDemandedBits,
1066     const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
1067     unsigned Depth, bool AssumeSingleUse) const {
1068   unsigned BitWidth = OriginalDemandedBits.getBitWidth();
1069   assert(Op.getScalarValueSizeInBits() == BitWidth &&
1070          "Mask size mismatches value type size!");
1071 
1072   // Don't know anything.
1073   Known = KnownBits(BitWidth);
1074 
1075   EVT VT = Op.getValueType();
1076   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
1077   unsigned NumElts = OriginalDemandedElts.getBitWidth();
1078   assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
1079          "Unexpected vector size");
1080 
1081   APInt DemandedBits = OriginalDemandedBits;
1082   APInt DemandedElts = OriginalDemandedElts;
1083   SDLoc dl(Op);
1084   auto &DL = TLO.DAG.getDataLayout();
1085 
1086   // Undef operand.
1087   if (Op.isUndef())
1088     return false;
1089 
1090   // We can't simplify target constants.
1091   if (Op.getOpcode() == ISD::TargetConstant)
1092     return false;
1093 
1094   if (Op.getOpcode() == ISD::Constant) {
1095     // We know all of the bits for a constant!
1096     Known = KnownBits::makeConstant(cast<ConstantSDNode>(Op)->getAPIntValue());
1097     return false;
1098   }
1099 
1100   if (Op.getOpcode() == ISD::ConstantFP) {
1101     // We know all of the bits for a floating point constant!
1102     Known = KnownBits::makeConstant(
1103         cast<ConstantFPSDNode>(Op)->getValueAPF().bitcastToAPInt());
1104     return false;
1105   }
1106 
1107   // Other users may use these bits.
1108   bool HasMultiUse = false;
1109   if (!AssumeSingleUse && !Op.getNode()->hasOneUse()) {
1110     if (Depth >= SelectionDAG::MaxRecursionDepth) {
1111       // Limit search depth.
1112       return false;
1113     }
1114     // Allow multiple uses, just set the DemandedBits/Elts to all bits.
1115     DemandedBits = APInt::getAllOnes(BitWidth);
1116     DemandedElts = APInt::getAllOnes(NumElts);
1117     HasMultiUse = true;
1118   } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) {
1119     // Not demanding any bits/elts from Op.
1120     return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1121   } else if (Depth >= SelectionDAG::MaxRecursionDepth) {
1122     // Limit search depth.
1123     return false;
1124   }
1125 
1126   KnownBits Known2;
1127   switch (Op.getOpcode()) {
1128   case ISD::SCALAR_TO_VECTOR: {
1129     if (VT.isScalableVector())
1130       return false;
1131     if (!DemandedElts[0])
1132       return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1133 
1134     KnownBits SrcKnown;
1135     SDValue Src = Op.getOperand(0);
1136     unsigned SrcBitWidth = Src.getScalarValueSizeInBits();
1137     APInt SrcDemandedBits = DemandedBits.zext(SrcBitWidth);
1138     if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1))
1139       return true;
1140 
1141     // Upper elements are undef, so only get the knownbits if we just demand
1142     // the bottom element.
1143     if (DemandedElts == 1)
1144       Known = SrcKnown.anyextOrTrunc(BitWidth);
1145     break;
1146   }
1147   case ISD::BUILD_VECTOR:
1148     // Collect the known bits that are shared by every demanded element.
1149     // TODO: Call SimplifyDemandedBits for non-constant demanded elements.
1150     Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1151     return false; // Don't fall through, will infinitely loop.
1152   case ISD::LOAD: {
1153     auto *LD = cast<LoadSDNode>(Op);
1154     if (getTargetConstantFromLoad(LD)) {
1155       Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1156       return false; // Don't fall through, will infinitely loop.
1157     }
1158     if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
1159       // If this is a ZEXTLoad and we are looking at the loaded value.
1160       EVT MemVT = LD->getMemoryVT();
1161       unsigned MemBits = MemVT.getScalarSizeInBits();
1162       Known.Zero.setBitsFrom(MemBits);
1163       return false; // Don't fall through, will infinitely loop.
1164     }
1165     break;
1166   }
1167   case ISD::INSERT_VECTOR_ELT: {
1168     if (VT.isScalableVector())
1169       return false;
1170     SDValue Vec = Op.getOperand(0);
1171     SDValue Scl = Op.getOperand(1);
1172     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
1173     EVT VecVT = Vec.getValueType();
1174 
1175     // If index isn't constant, assume we need all vector elements AND the
1176     // inserted element.
1177     APInt DemandedVecElts(DemandedElts);
1178     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) {
1179       unsigned Idx = CIdx->getZExtValue();
1180       DemandedVecElts.clearBit(Idx);
1181 
1182       // Inserted element is not required.
1183       if (!DemandedElts[Idx])
1184         return TLO.CombineTo(Op, Vec);
1185     }
1186 
1187     KnownBits KnownScl;
1188     unsigned NumSclBits = Scl.getScalarValueSizeInBits();
1189     APInt DemandedSclBits = DemandedBits.zextOrTrunc(NumSclBits);
1190     if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1191       return true;
1192 
1193     Known = KnownScl.anyextOrTrunc(BitWidth);
1194 
1195     KnownBits KnownVec;
1196     if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO,
1197                              Depth + 1))
1198       return true;
1199 
1200     if (!!DemandedVecElts)
1201       Known = KnownBits::commonBits(Known, KnownVec);
1202 
1203     return false;
1204   }
1205   case ISD::INSERT_SUBVECTOR: {
1206     if (VT.isScalableVector())
1207       return false;
1208     // Demand any elements from the subvector and the remainder from the src its
1209     // inserted into.
1210     SDValue Src = Op.getOperand(0);
1211     SDValue Sub = Op.getOperand(1);
1212     uint64_t Idx = Op.getConstantOperandVal(2);
1213     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
1214     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
1215     APInt DemandedSrcElts = DemandedElts;
1216     DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);
1217 
1218     KnownBits KnownSub, KnownSrc;
1219     if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
1220                              Depth + 1))
1221       return true;
1222     if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, KnownSrc, TLO,
1223                              Depth + 1))
1224       return true;
1225 
1226     Known.Zero.setAllBits();
1227     Known.One.setAllBits();
1228     if (!!DemandedSubElts)
1229       Known = KnownBits::commonBits(Known, KnownSub);
1230     if (!!DemandedSrcElts)
1231       Known = KnownBits::commonBits(Known, KnownSrc);
1232 
1233     // Attempt to avoid multi-use src if we don't need anything from it.
1234     if (!DemandedBits.isAllOnes() || !DemandedSubElts.isAllOnes() ||
1235         !DemandedSrcElts.isAllOnes()) {
1236       SDValue NewSub = SimplifyMultipleUseDemandedBits(
1237           Sub, DemandedBits, DemandedSubElts, TLO.DAG, Depth + 1);
1238       SDValue NewSrc = SimplifyMultipleUseDemandedBits(
1239           Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1240       if (NewSub || NewSrc) {
1241         NewSub = NewSub ? NewSub : Sub;
1242         NewSrc = NewSrc ? NewSrc : Src;
1243         SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc, NewSub,
1244                                         Op.getOperand(2));
1245         return TLO.CombineTo(Op, NewOp);
1246       }
1247     }
1248     break;
1249   }
1250   case ISD::EXTRACT_SUBVECTOR: {
1251     if (VT.isScalableVector())
1252       return false;
1253     // Offset the demanded elts by the subvector index.
1254     SDValue Src = Op.getOperand(0);
1255     if (Src.getValueType().isScalableVector())
1256       break;
1257     uint64_t Idx = Op.getConstantOperandVal(1);
1258     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
1259     APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
1260 
1261     if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, Known, TLO,
1262                              Depth + 1))
1263       return true;
1264 
1265     // Attempt to avoid multi-use src if we don't need anything from it.
1266     if (!DemandedBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
1267       SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
1268           Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1269       if (DemandedSrc) {
1270         SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc,
1271                                         Op.getOperand(1));
1272         return TLO.CombineTo(Op, NewOp);
1273       }
1274     }
1275     break;
1276   }
1277   case ISD::CONCAT_VECTORS: {
1278     if (VT.isScalableVector())
1279       return false;
1280     Known.Zero.setAllBits();
1281     Known.One.setAllBits();
1282     EVT SubVT = Op.getOperand(0).getValueType();
1283     unsigned NumSubVecs = Op.getNumOperands();
1284     unsigned NumSubElts = SubVT.getVectorNumElements();
1285     for (unsigned i = 0; i != NumSubVecs; ++i) {
1286       APInt DemandedSubElts =
1287           DemandedElts.extractBits(NumSubElts, i * NumSubElts);
1288       if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts,
1289                                Known2, TLO, Depth + 1))
1290         return true;
1291       // Known bits are shared by every demanded subvector element.
1292       if (!!DemandedSubElts)
1293         Known = KnownBits::commonBits(Known, Known2);
1294     }
1295     break;
1296   }
1297   case ISD::VECTOR_SHUFFLE: {
1298     assert(!VT.isScalableVector());
1299     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
1300 
1301     // Collect demanded elements from shuffle operands..
1302     APInt DemandedLHS, DemandedRHS;
1303     if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
1304                                 DemandedRHS))
1305       break;
1306 
1307     if (!!DemandedLHS || !!DemandedRHS) {
1308       SDValue Op0 = Op.getOperand(0);
1309       SDValue Op1 = Op.getOperand(1);
1310 
1311       Known.Zero.setAllBits();
1312       Known.One.setAllBits();
1313       if (!!DemandedLHS) {
1314         if (SimplifyDemandedBits(Op0, DemandedBits, DemandedLHS, Known2, TLO,
1315                                  Depth + 1))
1316           return true;
1317         Known = KnownBits::commonBits(Known, Known2);
1318       }
1319       if (!!DemandedRHS) {
1320         if (SimplifyDemandedBits(Op1, DemandedBits, DemandedRHS, Known2, TLO,
1321                                  Depth + 1))
1322           return true;
1323         Known = KnownBits::commonBits(Known, Known2);
1324       }
1325 
1326       // Attempt to avoid multi-use ops if we don't need anything from them.
1327       SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1328           Op0, DemandedBits, DemandedLHS, TLO.DAG, Depth + 1);
1329       SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1330           Op1, DemandedBits, DemandedRHS, TLO.DAG, Depth + 1);
1331       if (DemandedOp0 || DemandedOp1) {
1332         Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1333         Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1334         SDValue NewOp = TLO.DAG.getVectorShuffle(VT, dl, Op0, Op1, ShuffleMask);
1335         return TLO.CombineTo(Op, NewOp);
1336       }
1337     }
1338     break;
1339   }
1340   case ISD::AND: {
1341     SDValue Op0 = Op.getOperand(0);
1342     SDValue Op1 = Op.getOperand(1);
1343 
1344     // If the RHS is a constant, check to see if the LHS would be zero without
1345     // using the bits from the RHS.  Below, we use knowledge about the RHS to
1346     // simplify the LHS, here we're using information from the LHS to simplify
1347     // the RHS.
1348     if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1)) {
1349       // Do not increment Depth here; that can cause an infinite loop.
1350       KnownBits LHSKnown = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth);
1351       // If the LHS already has zeros where RHSC does, this 'and' is dead.
1352       if ((LHSKnown.Zero & DemandedBits) ==
1353           (~RHSC->getAPIntValue() & DemandedBits))
1354         return TLO.CombineTo(Op, Op0);
1355 
1356       // If any of the set bits in the RHS are known zero on the LHS, shrink
1357       // the constant.
1358       if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
1359                                  DemandedElts, TLO))
1360         return true;
1361 
1362       // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
1363       // constant, but if this 'and' is only clearing bits that were just set by
1364       // the xor, then this 'and' can be eliminated by shrinking the mask of
1365       // the xor. For example, for a 32-bit X:
1366       // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1
1367       if (isBitwiseNot(Op0) && Op0.hasOneUse() &&
1368           LHSKnown.One == ~RHSC->getAPIntValue()) {
1369         SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1);
1370         return TLO.CombineTo(Op, Xor);
1371       }
1372     }
1373 
1374     // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
1375     // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
1376     if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
1377         (Op0.getOperand(0).isUndef() ||
1378          ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
1379         Op0->hasOneUse()) {
1380       unsigned NumSubElts =
1381           Op0.getOperand(1).getValueType().getVectorNumElements();
1382       unsigned SubIdx = Op0.getConstantOperandVal(2);
1383       APInt DemandedSub =
1384           APInt::getBitsSet(NumElts, SubIdx, SubIdx + NumSubElts);
1385       KnownBits KnownSubMask =
1386           TLO.DAG.computeKnownBits(Op1, DemandedSub & DemandedElts, Depth + 1);
1387       if (DemandedBits.isSubsetOf(KnownSubMask.One)) {
1388         SDValue NewAnd =
1389             TLO.DAG.getNode(ISD::AND, dl, VT, Op0.getOperand(0), Op1);
1390         SDValue NewInsert =
1391             TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, NewAnd,
1392                             Op0.getOperand(1), Op0.getOperand(2));
1393         return TLO.CombineTo(Op, NewInsert);
1394       }
1395     }
1396 
1397     if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1398                              Depth + 1))
1399       return true;
1400     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
1401     if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts,
1402                              Known2, TLO, Depth + 1))
1403       return true;
1404     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
1405 
1406     // If all of the demanded bits are known one on one side, return the other.
1407     // These bits cannot contribute to the result of the 'and'.
1408     if (DemandedBits.isSubsetOf(Known2.Zero | Known.One))
1409       return TLO.CombineTo(Op, Op0);
1410     if (DemandedBits.isSubsetOf(Known.Zero | Known2.One))
1411       return TLO.CombineTo(Op, Op1);
1412     // If all of the demanded bits in the inputs are known zeros, return zero.
1413     if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1414       return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
1415     // If the RHS is a constant, see if we can simplify it.
1416     if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
1417                                TLO))
1418       return true;
1419     // If the operation can be done in a smaller type, do so.
1420     if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1421       return true;
1422 
1423     // Attempt to avoid multi-use ops if we don't need anything from them.
1424     if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1425       SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1426           Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1427       SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1428           Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1429       if (DemandedOp0 || DemandedOp1) {
1430         Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1431         Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1432         SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1433         return TLO.CombineTo(Op, NewOp);
1434       }
1435     }
1436 
1437     Known &= Known2;
1438     break;
1439   }
  case ISD::OR: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);

    // Simplify the RHS first; any bits known one there need not be demanded
    // from the LHS (they dominate the OR result), hence the ~Known.One mask.
    if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
                             Depth + 1))
      return true;
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts,
                             Known2, TLO, Depth + 1))
      return true;
    assert(!Known2.hasConflict() && "Bits known to be one AND zero?");

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'or'.
    if (DemandedBits.isSubsetOf(Known2.One | Known.Zero))
      return TLO.CombineTo(Op, Op0);
    if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
      return TLO.CombineTo(Op, Op1);
    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
      return true;
    // If the operation can be done in a smaller type, do so.
    if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
      return true;

    // Attempt to avoid multi-use ops if we don't need anything from them.
    if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
      SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
          Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
      SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
          Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
      if (DemandedOp0 || DemandedOp1) {
        Op0 = DemandedOp0 ? DemandedOp0 : Op0;
        Op1 = DemandedOp1 ? DemandedOp1 : Op1;
        SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
        return TLO.CombineTo(Op, NewOp);
      }
    }

    // (or (and X, C1), (and (or X, Y), C2)) -> (or (and X, C1|C2), (and Y, C2))
    // TODO: Use SimplifyMultipleUseDemandedBits to peek through masks.
    if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::AND &&
        Op0->hasOneUse() && Op1->hasOneUse()) {
      // Attempt to match all commutations - m_c_Or would've been useful!
      // I selects which OR operand supplies (X, C1); 1-I supplies the
      // (or X, Y) & C2 side.  J selects which operand of the inner OR is X.
      for (int I = 0; I != 2; ++I) {
        SDValue X = Op.getOperand(I).getOperand(0);
        SDValue C1 = Op.getOperand(I).getOperand(1);
        SDValue Alt = Op.getOperand(1 - I).getOperand(0);
        SDValue C2 = Op.getOperand(1 - I).getOperand(1);
        if (Alt.getOpcode() == ISD::OR) {
          for (int J = 0; J != 2; ++J) {
            if (X == Alt.getOperand(J)) {
              SDValue Y = Alt.getOperand(1 - J);
              // Only fold if C1|C2 constant-folds (both are constants/splats).
              if (SDValue C12 = TLO.DAG.FoldConstantArithmetic(ISD::OR, dl, VT,
                                                               {C1, C2})) {
                SDValue MaskX = TLO.DAG.getNode(ISD::AND, dl, VT, X, C12);
                SDValue MaskY = TLO.DAG.getNode(ISD::AND, dl, VT, Y, C2);
                return TLO.CombineTo(
                    Op, TLO.DAG.getNode(ISD::OR, dl, VT, MaskX, MaskY));
              }
            }
          }
        }
      }
    }

    // Known bits of the OR: one if either side is one, zero if both are zero.
    Known |= Known2;
    break;
  }
  case ISD::XOR: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);

    // Unlike AND/OR, known bits on one side cannot reduce the bits demanded
    // from the other side, so both operands see the full DemandedBits.
    if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
                             Depth + 1))
      return true;
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO,
                             Depth + 1))
      return true;
    assert(!Known2.hasConflict() && "Bits known to be one AND zero?");

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'xor'.
    if (DemandedBits.isSubsetOf(Known.Zero))
      return TLO.CombineTo(Op, Op0);
    if (DemandedBits.isSubsetOf(Known2.Zero))
      return TLO.CombineTo(Op, Op1);
    // If the operation can be done in a smaller type, do so.
    if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
      return true;

    // If all of the unknown bits are known to be zero on one side or the other
    // turn this into an *inclusive* or.
    //    e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
    if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));

    ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
    if (C) {
      // If one side is a constant, and all of the set bits in the constant are
      // also known set on the other side, turn this into an AND, as we know
      // the bits will be cleared.
      //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
      // NB: it is okay if more bits are known than are requested
      if (C->getAPIntValue() == Known2.One) {
        SDValue ANDC =
            TLO.DAG.getConstant(~C->getAPIntValue() & DemandedBits, dl, VT);
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, ANDC));
      }

      // If the RHS is a constant, see if we can change it. Don't alter a -1
      // constant because that's a 'not' op, and that is better for combining
      // and codegen.
      if (!C->isAllOnes() && DemandedBits.isSubsetOf(C->getAPIntValue())) {
        // We're flipping all demanded bits. Flip the undemanded bits too.
        SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
        return TLO.CombineTo(Op, New);
      }

      unsigned Op0Opcode = Op0.getOpcode();
      if ((Op0Opcode == ISD::SRL || Op0Opcode == ISD::SHL) && Op0.hasOneUse()) {
        if (ConstantSDNode *ShiftC =
                isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
          // Don't crash on an oversized shift. We can not guarantee that a
          // bogus shift has been simplified to undef.
          if (ShiftC->getAPIntValue().ult(BitWidth)) {
            uint64_t ShiftAmt = ShiftC->getZExtValue();
            // Ones = mask of bits the shift can actually produce (the rest are
            // the zeros it shifts in).
            APInt Ones = APInt::getAllOnes(BitWidth);
            Ones = Op0Opcode == ISD::SHL ? Ones.shl(ShiftAmt)
                                         : Ones.lshr(ShiftAmt);
            const TargetLowering &TLI = TLO.DAG.getTargetLoweringInfo();
            if ((DemandedBits & C->getAPIntValue()) == (DemandedBits & Ones) &&
                TLI.isDesirableToCommuteXorWithShift(Op.getNode())) {
              // If the xor constant is a demanded mask, do a 'not' before the
              // shift:
              // xor (X << ShiftC), XorC --> (not X) << ShiftC
              // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
              SDValue Not = TLO.DAG.getNOT(dl, Op0.getOperand(0), VT);
              return TLO.CombineTo(Op, TLO.DAG.getNode(Op0Opcode, dl, VT, Not,
                                                       Op0.getOperand(1)));
            }
          }
        }
      }
    }

    // If we can't turn this into a 'not', try to shrink the constant.
    if (!C || !C->isAllOnes())
      if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
        return true;

    // Attempt to avoid multi-use ops if we don't need anything from them.
    if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
      SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
          Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
      SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
          Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
      if (DemandedOp0 || DemandedOp1) {
        Op0 = DemandedOp0 ? DemandedOp0 : Op0;
        Op1 = DemandedOp1 ? DemandedOp1 : Op1;
        SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
        return TLO.CombineTo(Op, NewOp);
      }
    }

    Known ^= Known2;
    break;
  }
  case ISD::SELECT:
    // Either arm may be chosen at runtime, so both see the full DemandedBits.
    // Operand 2 is the false value, operand 1 the true value.
    if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, Known, TLO,
                             Depth + 1))
      return true;
    if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, Known2, TLO,
                             Depth + 1))
      return true;
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    assert(!Known2.hasConflict() && "Bits known to be one AND zero?");

    // If the operands are constants, see if we can simplify them.
    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
      return true;

    // Only known if known in both the LHS and RHS.
    Known = KnownBits::commonBits(Known, Known2);
    break;
  case ISD::VSELECT:
    // As for SELECT, but the per-element condition means we also thread the
    // demanded-elements mask through to both arms.
    if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
                             Known, TLO, Depth + 1))
      return true;
    if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
                             Known2, TLO, Depth + 1))
      return true;
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    assert(!Known2.hasConflict() && "Bits known to be one AND zero?");

    // Only known if known in both the LHS and RHS.
    Known = KnownBits::commonBits(Known, Known2);
    break;
  case ISD::SELECT_CC:
    // Operands 0/1 are the compared values; operands 2/3 are the selected
    // true/false results, and only those feed the result bits.
    if (SimplifyDemandedBits(Op.getOperand(3), DemandedBits, Known, TLO,
                             Depth + 1))
      return true;
    if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, Known2, TLO,
                             Depth + 1))
      return true;
    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
    assert(!Known2.hasConflict() && "Bits known to be one AND zero?");

    // If the operands are constants, see if we can simplify them.
    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
      return true;

    // Only known if known in both the LHS and RHS.
    Known = KnownBits::commonBits(Known, Known2);
    break;
  case ISD::SETCC: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
    // If (1) we only need the sign-bit, (2) the setcc operands are the same
    // width as the setcc result, and (3) the result of a setcc conforms to 0 or
    // -1, we may be able to bypass the setcc.
    if (DemandedBits.isSignMask() &&
        Op0.getScalarValueSizeInBits() == BitWidth &&
        getBooleanContents(Op0.getValueType()) ==
            BooleanContent::ZeroOrNegativeOneBooleanContent) {
      // If we're testing X < 0, then this compare isn't needed - just use X!
      // FIXME: We're limiting to integer types here, but this should also work
      // if we don't care about FP signed-zero. The use of SETLT with FP means
      // that we don't care about NaNs.
      if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
          (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
        return TLO.CombineTo(Op, Op0);

      // TODO: Should we check for other forms of sign-bit comparisons?
      // Examples: X <= -1, X >= 0
    }
    // A 0/1 boolean setcc can only produce a value in bit 0; everything above
    // it is known zero.
    if (getBooleanContents(Op0.getValueType()) ==
            TargetLowering::ZeroOrOneBooleanContent &&
        BitWidth > 1)
      Known.Zero.setBitsFrom(1);
    break;
  }
  case ISD::SHL: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    EVT ShiftVT = Op1.getValueType();

    // Constant (uniform, in-range) shift amount: we can translate the demanded
    // mask through the shift exactly.
    if (const APInt *SA =
            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
      unsigned ShAmt = SA->getZExtValue();
      if (ShAmt == 0)
        return TLO.CombineTo(Op, Op0);

      // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
      // single shift.  We can do this if the bottom bits (which are shifted
      // out) are never demanded.
      // TODO - support non-uniform vector amounts.
      if (Op0.getOpcode() == ISD::SRL) {
        if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
          if (const APInt *SA2 =
                  TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
            unsigned C1 = SA2->getZExtValue();
            // Net shift is ShAmt - C1: left if positive, right if negative.
            unsigned Opc = ISD::SHL;
            int Diff = ShAmt - C1;
            if (Diff < 0) {
              Diff = -Diff;
              Opc = ISD::SRL;
            }
            SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
            return TLO.CombineTo(
                Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
          }
        }
      }

      // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
      // are not demanded. This will likely allow the anyext to be folded away.
      // TODO - support non-uniform vector amounts.
      if (Op0.getOpcode() == ISD::ANY_EXTEND) {
        SDValue InnerOp = Op0.getOperand(0);
        EVT InnerVT = InnerOp.getValueType();
        unsigned InnerBits = InnerVT.getScalarSizeInBits();
        if (ShAmt < InnerBits && DemandedBits.getActiveBits() <= InnerBits &&
            isTypeDesirableForOp(ISD::SHL, InnerVT)) {
          // Make sure the narrow shift-amount type can still hold ShAmt.
          EVT ShTy = getShiftAmountTy(InnerVT, DL);
          if (!APInt(BitWidth, ShAmt).isIntN(ShTy.getSizeInBits()))
            ShTy = InnerVT;
          SDValue NarrowShl =
              TLO.DAG.getNode(ISD::SHL, dl, InnerVT, InnerOp,
                              TLO.DAG.getConstant(ShAmt, dl, ShTy));
          return TLO.CombineTo(
              Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
        }

        // Repeat the SHL optimization above in cases where an extension
        // intervenes: (shl (anyext (shr x, c1)), c2) to
        // (shl (anyext x), c2-c1).  This requires that the bottom c1 bits
        // aren't demanded (as above) and that the shifted upper c1 bits of
        // x aren't demanded.
        // TODO - support non-uniform vector amounts.
        if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
            InnerOp.hasOneUse()) {
          if (const APInt *SA2 =
                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
            unsigned InnerShAmt = SA2->getZExtValue();
            if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                DemandedBits.getActiveBits() <=
                    (InnerBits - InnerShAmt + ShAmt) &&
                DemandedBits.countTrailingZeros() >= ShAmt) {
              SDValue NewSA =
                  TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT);
              SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
                                               InnerOp.getOperand(0));
              return TLO.CombineTo(
                  Op, TLO.DAG.getNode(ISD::SHL, dl, VT, NewExt, NewSA));
            }
          }
        }
      }

      // Only the bits that end up in demanded positions are needed from Op0.
      APInt InDemandedMask = DemandedBits.lshr(ShAmt);
      if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
                               Depth + 1))
        return true;
      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
      Known.Zero <<= ShAmt;
      Known.One <<= ShAmt;
      // low bits known zero.
      Known.Zero.setLowBits(ShAmt);

      // Attempt to avoid multi-use ops if we don't need anything from them.
      if (!InDemandedMask.isAllOnesValue() || !DemandedElts.isAllOnesValue()) {
        SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
            Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
        if (DemandedOp0) {
          SDValue NewOp = TLO.DAG.getNode(ISD::SHL, dl, VT, DemandedOp0, Op1);
          return TLO.CombineTo(Op, NewOp);
        }
      }

      // Try shrinking the operation as long as the shift amount will still be
      // in range.
      if ((ShAmt < DemandedBits.getActiveBits()) &&
          ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
        return true;
    } else {
      // This is a variable shift, so we can't shift the demand mask by a known
      // amount. But if we are not demanding high bits, then we are not
      // demanding those bits from the pre-shifted operand either.
      if (unsigned CTLZ = DemandedBits.countLeadingZeros()) {
        APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
        if (SimplifyDemandedBits(Op0, DemandedFromOp, DemandedElts, Known, TLO,
                                 Depth + 1)) {
          SDNodeFlags Flags = Op.getNode()->getFlags();
          if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
            // Disable the nsw and nuw flags. We can no longer guarantee that we
            // won't wrap after simplification.
            Flags.setNoSignedWrap(false);
            Flags.setNoUnsignedWrap(false);
            Op->setFlags(Flags);
          }
          return true;
        }
        // Known was computed against the reduced mask; it is not valid for the
        // unshifted result, so drop it.
        Known.resetAll();
      }
    }

    // If we are only demanding sign bits then we can use the shift source
    // directly.
    if (const APInt *MaxSA =
            TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
      unsigned ShAmt = MaxSA->getZExtValue();
      unsigned NumSignBits =
          TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
      unsigned UpperDemandedBits = BitWidth - DemandedBits.countTrailingZeros();
      if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
        return TLO.CombineTo(Op, Op0);
    }
    break;
  }
  case ISD::SRL: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    EVT ShiftVT = Op1.getValueType();

    // Try to match AVG patterns.
    if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
                                        DemandedElts, Depth + 1))
      return TLO.CombineTo(Op, AVG);

    // Constant (uniform, in-range) shift amount.
    if (const APInt *SA =
            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
      unsigned ShAmt = SA->getZExtValue();
      if (ShAmt == 0)
        return TLO.CombineTo(Op, Op0);

      // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
      // single shift.  We can do this if the top bits (which are shifted out)
      // are never demanded.
      // TODO - support non-uniform vector amounts.
      if (Op0.getOpcode() == ISD::SHL) {
        if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
          if (const APInt *SA2 =
                  TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
            unsigned C1 = SA2->getZExtValue();
            // Net shift is ShAmt - C1: right if positive, left if negative.
            unsigned Opc = ISD::SRL;
            int Diff = ShAmt - C1;
            if (Diff < 0) {
              Diff = -Diff;
              Opc = ISD::SHL;
            }
            SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
            return TLO.CombineTo(
                Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
          }
        }
      }

      APInt InDemandedMask = (DemandedBits << ShAmt);

      // If the shift is exact, then it does demand the low bits (and knows that
      // they are zero).
      if (Op->getFlags().hasExact())
        InDemandedMask.setLowBits(ShAmt);

      // Compute the new bits that are at the top now.
      if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
                               Depth + 1))
        return true;
      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
      Known.Zero.lshrInPlace(ShAmt);
      Known.One.lshrInPlace(ShAmt);
      // High bits known zero.
      Known.Zero.setHighBits(ShAmt);

      // Attempt to avoid multi-use ops if we don't need anything from them.
      if (!InDemandedMask.isAllOnesValue() || !DemandedElts.isAllOnesValue()) {
        SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
            Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
        if (DemandedOp0) {
          SDValue NewOp = TLO.DAG.getNode(ISD::SRL, dl, VT, DemandedOp0, Op1);
          return TLO.CombineTo(Op, NewOp);
        }
      }
    }
    break;
  }
  case ISD::SRA: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    EVT ShiftVT = Op1.getValueType();

    // If we only want bits that already match the signbit then we don't need
    // to shift.
    unsigned NumHiDemandedBits = BitWidth - DemandedBits.countTrailingZeros();
    if (TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1) >=
        NumHiDemandedBits)
      return TLO.CombineTo(Op, Op0);

    // If this is an arithmetic shift right and only the low-bit is set, we can
    // always convert this into a logical shr, even if the shift amount is
    // variable.  The low bit of the shift cannot be an input sign bit unless
    // the shift amount is >= the size of the datatype, which is undefined.
    if (DemandedBits.isOne())
      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));

    // Try to match AVG patterns.
    if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
                                        DemandedElts, Depth + 1))
      return TLO.CombineTo(Op, AVG);

    // Constant (uniform, in-range) shift amount.
    if (const APInt *SA =
            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
      unsigned ShAmt = SA->getZExtValue();
      if (ShAmt == 0)
        return TLO.CombineTo(Op, Op0);

      APInt InDemandedMask = (DemandedBits << ShAmt);

      // If the shift is exact, then it does demand the low bits (and knows that
      // they are zero).
      if (Op->getFlags().hasExact())
        InDemandedMask.setLowBits(ShAmt);

      // If any of the demanded bits are produced by the sign extension, we also
      // demand the input sign bit.
      if (DemandedBits.countLeadingZeros() < ShAmt)
        InDemandedMask.setSignBit();

      if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
                               Depth + 1))
        return true;
      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
      Known.Zero.lshrInPlace(ShAmt);
      Known.One.lshrInPlace(ShAmt);

      // If the input sign bit is known to be zero, or if none of the top bits
      // are demanded, turn this into an unsigned shift right.
      if (Known.Zero[BitWidth - ShAmt - 1] ||
          DemandedBits.countLeadingZeros() >= ShAmt) {
        SDNodeFlags Flags;
        Flags.setExact(Op->getFlags().hasExact());
        return TLO.CombineTo(
            Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1, Flags));
      }

      // If only one bit is demanded, it must be the sign bit smeared down to
      // that position; extract it with a logical shift instead.
      int Log2 = DemandedBits.exactLogBase2();
      if (Log2 >= 0) {
        // The bit must come from the sign.
        SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT);
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA));
      }

      if (Known.One[BitWidth - ShAmt - 1])
        // New bits are known one.
        Known.One.setHighBits(ShAmt);

      // Attempt to avoid multi-use ops if we don't need anything from them.
      if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
        SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
            Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
        if (DemandedOp0) {
          SDValue NewOp = TLO.DAG.getNode(ISD::SRA, dl, VT, DemandedOp0, Op1);
          return TLO.CombineTo(Op, NewOp);
        }
      }
    }
    break;
  }
  case ISD::FSHL:
  case ISD::FSHR: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    SDValue Op2 = Op.getOperand(2);
    bool IsFSHL = (Op.getOpcode() == ISD::FSHL);

    if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) {
      // Funnel shift amounts operate modulo the bit width.
      unsigned Amt = SA->getAPIntValue().urem(BitWidth);

      // For fshl, 0-shift returns the 1st arg.
      // For fshr, 0-shift returns the 2nd arg.
      if (Amt == 0) {
        if (SimplifyDemandedBits(IsFSHL ? Op0 : Op1, DemandedBits, DemandedElts,
                                 Known, TLO, Depth + 1))
          return true;
        break;
      }

      // fshl: (Op0 << Amt) | (Op1 >> (BW - Amt))
      // fshr: (Op0 << (BW - Amt)) | (Op1 >> Amt)
      APInt Demanded0 = DemandedBits.lshr(IsFSHL ? Amt : (BitWidth - Amt));
      APInt Demanded1 = DemandedBits << (IsFSHL ? (BitWidth - Amt) : Amt);
      if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
                               Depth + 1))
        return true;
      if (SimplifyDemandedBits(Op1, Demanded1, DemandedElts, Known, TLO,
                               Depth + 1))
        return true;

      // Shift the known bits of each source into the positions they occupy in
      // the funnel-shift result, then merge (the two halves are disjoint).
      Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt));
      Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt));
      Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
      Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
      Known.One |= Known2.One;
      Known.Zero |= Known2.Zero;

      // Attempt to avoid multi-use ops if we don't need anything from them.
      if (!Demanded0.isAllOnes() || !Demanded1.isAllOnes() ||
          !DemandedElts.isAllOnes()) {
        SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
            Op0, Demanded0, DemandedElts, TLO.DAG, Depth + 1);
        SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
            Op1, Demanded1, DemandedElts, TLO.DAG, Depth + 1);
        if (DemandedOp0 || DemandedOp1) {
          DemandedOp0 = DemandedOp0 ? DemandedOp0 : Op0;
          DemandedOp1 = DemandedOp1 ? DemandedOp1 : Op1;
          SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedOp0,
                                          DemandedOp1, Op2);
          return TLO.CombineTo(Op, NewOp);
        }
      }
    }

    // For pow-2 bitwidths we only demand the bottom modulo amt bits.
    if (isPowerOf2_32(BitWidth)) {
      APInt DemandedAmtBits(Op2.getScalarValueSizeInBits(), BitWidth - 1);
      if (SimplifyDemandedBits(Op2, DemandedAmtBits, DemandedElts,
                               Known2, TLO, Depth + 1))
        return true;
    }
    break;
  }
  case ISD::ROTL:
  case ISD::ROTR: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    bool IsROTL = (Op.getOpcode() == ISD::ROTL);

    // If we're rotating an 0/-1 value, then it stays an 0/-1 value.
    if (BitWidth == TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1))
      return TLO.CombineTo(Op, Op0);

    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
      // Rotate amounts operate modulo the bit width.
      unsigned Amt = SA->getAPIntValue().urem(BitWidth);
      unsigned RevAmt = BitWidth - Amt;

      // rotl: (Op0 << Amt) | (Op0 >> (BW - Amt))
      // rotr: (Op0 << (BW - Amt)) | (Op0 >> Amt)
      APInt Demanded0 = DemandedBits.rotr(IsROTL ? Amt : RevAmt);
      if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
                               Depth + 1))
        return true;

      // rot*(x, 0) --> x
      if (Amt == 0)
        return TLO.CombineTo(Op, Op0);

      // See if we don't demand either half of the rotated bits.
      // If the bits wrapped around from the top aren't demanded, this is just
      // a plain left shift (and vice versa for the right-shift case).
      if ((!TLO.LegalOperations() || isOperationLegal(ISD::SHL, VT)) &&
          DemandedBits.countTrailingZeros() >= (IsROTL ? Amt : RevAmt)) {
        Op1 = TLO.DAG.getConstant(IsROTL ? Amt : RevAmt, dl, Op1.getValueType());
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, Op1));
      }
      if ((!TLO.LegalOperations() || isOperationLegal(ISD::SRL, VT)) &&
          DemandedBits.countLeadingZeros() >= (IsROTL ? RevAmt : Amt)) {
        Op1 = TLO.DAG.getConstant(IsROTL ? RevAmt : Amt, dl, Op1.getValueType());
        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
      }
    }

    // For pow-2 bitwidths we only demand the bottom modulo amt bits.
    if (isPowerOf2_32(BitWidth)) {
      APInt DemandedAmtBits(Op1.getScalarValueSizeInBits(), BitWidth - 1);
      if (SimplifyDemandedBits(Op1, DemandedAmtBits, DemandedElts, Known2, TLO,
                               Depth + 1))
        return true;
    }
    break;
  }
2082   case ISD::UMIN: {
2083     // Check if one arg is always less than (or equal) to the other arg.
2084     SDValue Op0 = Op.getOperand(0);
2085     SDValue Op1 = Op.getOperand(1);
2086     KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2087     KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2088     Known = KnownBits::umin(Known0, Known1);
2089     if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2090       return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2091     if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2092       return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2093     break;
2094   }
2095   case ISD::UMAX: {
2096     // Check if one arg is always greater than (or equal) to the other arg.
2097     SDValue Op0 = Op.getOperand(0);
2098     SDValue Op1 = Op.getOperand(1);
2099     KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2100     KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2101     Known = KnownBits::umax(Known0, Known1);
2102     if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2103       return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2104     if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2105       return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2106     break;
2107   }
2108   case ISD::BITREVERSE: {
2109     SDValue Src = Op.getOperand(0);
2110     APInt DemandedSrcBits = DemandedBits.reverseBits();
2111     if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2112                              Depth + 1))
2113       return true;
2114     Known.One = Known2.One.reverseBits();
2115     Known.Zero = Known2.Zero.reverseBits();
2116     break;
2117   }
2118   case ISD::BSWAP: {
2119     SDValue Src = Op.getOperand(0);
2120 
2121     // If the only bits demanded come from one byte of the bswap result,
2122     // just shift the input byte into position to eliminate the bswap.
2123     unsigned NLZ = DemandedBits.countLeadingZeros();
2124     unsigned NTZ = DemandedBits.countTrailingZeros();
2125 
2126     // Round NTZ down to the next byte.  If we have 11 trailing zeros, then
2127     // we need all the bits down to bit 8.  Likewise, round NLZ.  If we
2128     // have 14 leading zeros, round to 8.
2129     NLZ = alignDown(NLZ, 8);
2130     NTZ = alignDown(NTZ, 8);
2131     // If we need exactly one byte, we can do this transformation.
2132     if (BitWidth - NLZ - NTZ == 8) {
2133       // Replace this with either a left or right shift to get the byte into
2134       // the right place.
2135       unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
2136       if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
2137         EVT ShiftAmtTy = getShiftAmountTy(VT, DL);
2138         unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
2139         SDValue ShAmt = TLO.DAG.getConstant(ShiftAmount, dl, ShiftAmtTy);
2140         SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
2141         return TLO.CombineTo(Op, NewOp);
2142       }
2143     }
2144 
2145     APInt DemandedSrcBits = DemandedBits.byteSwap();
2146     if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2147                              Depth + 1))
2148       return true;
2149     Known.One = Known2.One.byteSwap();
2150     Known.Zero = Known2.Zero.byteSwap();
2151     break;
2152   }
2153   case ISD::CTPOP: {
2154     // If only 1 bit is demanded, replace with PARITY as long as we're before
2155     // op legalization.
2156     // FIXME: Limit to scalars for now.
2157     if (DemandedBits.isOne() && !TLO.LegalOps && !VT.isVector())
2158       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::PARITY, dl, VT,
2159                                                Op.getOperand(0)));
2160 
2161     Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2162     break;
2163   }
2164   case ISD::SIGN_EXTEND_INREG: {
2165     SDValue Op0 = Op.getOperand(0);
2166     EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2167     unsigned ExVTBits = ExVT.getScalarSizeInBits();
2168 
2169     // If we only care about the highest bit, don't bother shifting right.
2170     if (DemandedBits.isSignMask()) {
2171       unsigned MinSignedBits =
2172           TLO.DAG.ComputeMaxSignificantBits(Op0, DemandedElts, Depth + 1);
2173       bool AlreadySignExtended = ExVTBits >= MinSignedBits;
2174       // However if the input is already sign extended we expect the sign
2175       // extension to be dropped altogether later and do not simplify.
2176       if (!AlreadySignExtended) {
2177         // Compute the correct shift amount type, which must be getShiftAmountTy
2178         // for scalar types after legalization.
2179         SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ExVTBits, dl,
2180                                                getShiftAmountTy(VT, DL));
2181         return TLO.CombineTo(Op,
2182                              TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
2183       }
2184     }
2185 
2186     // If none of the extended bits are demanded, eliminate the sextinreg.
2187     if (DemandedBits.getActiveBits() <= ExVTBits)
2188       return TLO.CombineTo(Op, Op0);
2189 
2190     APInt InputDemandedBits = DemandedBits.getLoBits(ExVTBits);
2191 
2192     // Since the sign extended bits are demanded, we know that the sign
2193     // bit is demanded.
2194     InputDemandedBits.setBit(ExVTBits - 1);
2195 
2196     if (SimplifyDemandedBits(Op0, InputDemandedBits, DemandedElts, Known, TLO,
2197                              Depth + 1))
2198       return true;
2199     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
2200 
2201     // If the sign bit of the input is known set or clear, then we know the
2202     // top bits of the result.
2203 
2204     // If the input sign bit is known zero, convert this into a zero extension.
2205     if (Known.Zero[ExVTBits - 1])
2206       return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT));
2207 
2208     APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits);
2209     if (Known.One[ExVTBits - 1]) { // Input sign bit known set
2210       Known.One.setBitsFrom(ExVTBits);
2211       Known.Zero &= Mask;
2212     } else { // Input sign bit unknown
2213       Known.Zero &= Mask;
2214       Known.One &= Mask;
2215     }
2216     break;
2217   }
2218   case ISD::BUILD_PAIR: {
2219     EVT HalfVT = Op.getOperand(0).getValueType();
2220     unsigned HalfBitWidth = HalfVT.getScalarSizeInBits();
2221 
2222     APInt MaskLo = DemandedBits.getLoBits(HalfBitWidth).trunc(HalfBitWidth);
2223     APInt MaskHi = DemandedBits.getHiBits(HalfBitWidth).trunc(HalfBitWidth);
2224 
2225     KnownBits KnownLo, KnownHi;
2226 
2227     if (SimplifyDemandedBits(Op.getOperand(0), MaskLo, KnownLo, TLO, Depth + 1))
2228       return true;
2229 
2230     if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
2231       return true;
2232 
2233     Known = KnownHi.concat(KnownLo);
2234     break;
2235   }
2236   case ISD::ZERO_EXTEND_VECTOR_INREG:
2237     if (VT.isScalableVector())
2238       return false;
2239     [[fallthrough]];
2240   case ISD::ZERO_EXTEND: {
2241     SDValue Src = Op.getOperand(0);
2242     EVT SrcVT = Src.getValueType();
2243     unsigned InBits = SrcVT.getScalarSizeInBits();
2244     unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2245     bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
2246 
2247     // If none of the top bits are demanded, convert this into an any_extend.
2248     if (DemandedBits.getActiveBits() <= InBits) {
2249       // If we only need the non-extended bits of the bottom element
2250       // then we can just bitcast to the result.
2251       if (IsLE && IsVecInReg && DemandedElts == 1 &&
2252           VT.getSizeInBits() == SrcVT.getSizeInBits())
2253         return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2254 
2255       unsigned Opc =
2256           IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2257       if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2258         return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2259     }
2260 
2261     APInt InDemandedBits = DemandedBits.trunc(InBits);
2262     APInt InDemandedElts = DemandedElts.zext(InElts);
2263     if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2264                              Depth + 1))
2265       return true;
2266     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
2267     assert(Known.getBitWidth() == InBits && "Src width has changed?");
2268     Known = Known.zext(BitWidth);
2269 
2270     // Attempt to avoid multi-use ops if we don't need anything from them.
2271     if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2272             Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2273       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2274     break;
2275   }
2276   case ISD::SIGN_EXTEND_VECTOR_INREG:
2277     if (VT.isScalableVector())
2278       return false;
2279     [[fallthrough]];
2280   case ISD::SIGN_EXTEND: {
2281     SDValue Src = Op.getOperand(0);
2282     EVT SrcVT = Src.getValueType();
2283     unsigned InBits = SrcVT.getScalarSizeInBits();
2284     unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2285     bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
2286 
2287     // If none of the top bits are demanded, convert this into an any_extend.
2288     if (DemandedBits.getActiveBits() <= InBits) {
2289       // If we only need the non-extended bits of the bottom element
2290       // then we can just bitcast to the result.
2291       if (IsLE && IsVecInReg && DemandedElts == 1 &&
2292           VT.getSizeInBits() == SrcVT.getSizeInBits())
2293         return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2294 
2295       unsigned Opc =
2296           IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2297       if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2298         return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2299     }
2300 
2301     APInt InDemandedBits = DemandedBits.trunc(InBits);
2302     APInt InDemandedElts = DemandedElts.zext(InElts);
2303 
2304     // Since some of the sign extended bits are demanded, we know that the sign
2305     // bit is demanded.
2306     InDemandedBits.setBit(InBits - 1);
2307 
2308     if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2309                              Depth + 1))
2310       return true;
2311     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
2312     assert(Known.getBitWidth() == InBits && "Src width has changed?");
2313 
2314     // If the sign bit is known one, the top bits match.
2315     Known = Known.sext(BitWidth);
2316 
2317     // If the sign bit is known zero, convert this to a zero extend.
2318     if (Known.isNonNegative()) {
2319       unsigned Opc =
2320           IsVecInReg ? ISD::ZERO_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND;
2321       if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2322         return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2323     }
2324 
2325     // Attempt to avoid multi-use ops if we don't need anything from them.
2326     if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2327             Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2328       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2329     break;
2330   }
2331   case ISD::ANY_EXTEND_VECTOR_INREG:
2332     if (VT.isScalableVector())
2333       return false;
2334     [[fallthrough]];
2335   case ISD::ANY_EXTEND: {
2336     SDValue Src = Op.getOperand(0);
2337     EVT SrcVT = Src.getValueType();
2338     unsigned InBits = SrcVT.getScalarSizeInBits();
2339     unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2340     bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
2341 
2342     // If we only need the bottom element then we can just bitcast.
2343     // TODO: Handle ANY_EXTEND?
2344     if (IsLE && IsVecInReg && DemandedElts == 1 &&
2345         VT.getSizeInBits() == SrcVT.getSizeInBits())
2346       return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2347 
2348     APInt InDemandedBits = DemandedBits.trunc(InBits);
2349     APInt InDemandedElts = DemandedElts.zext(InElts);
2350     if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2351                              Depth + 1))
2352       return true;
2353     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
2354     assert(Known.getBitWidth() == InBits && "Src width has changed?");
2355     Known = Known.anyext(BitWidth);
2356 
2357     // Attempt to avoid multi-use ops if we don't need anything from them.
2358     if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2359             Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2360       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2361     break;
2362   }
2363   case ISD::TRUNCATE: {
2364     SDValue Src = Op.getOperand(0);
2365 
2366     // Simplify the input, using demanded bit information, and compute the known
2367     // zero/one bits live out.
2368     unsigned OperandBitWidth = Src.getScalarValueSizeInBits();
2369     APInt TruncMask = DemandedBits.zext(OperandBitWidth);
2370     if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, Known, TLO,
2371                              Depth + 1))
2372       return true;
2373     Known = Known.trunc(BitWidth);
2374 
2375     // Attempt to avoid multi-use ops if we don't need anything from them.
2376     if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2377             Src, TruncMask, DemandedElts, TLO.DAG, Depth + 1))
2378       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, NewSrc));
2379 
2380     // If the input is only used by this truncate, see if we can shrink it based
2381     // on the known demanded bits.
2382     switch (Src.getOpcode()) {
2383     default:
2384       break;
2385     case ISD::SRL:
2386       // Shrink SRL by a constant if none of the high bits shifted in are
2387       // demanded.
2388       if (TLO.LegalTypes() && !isTypeDesirableForOp(ISD::SRL, VT))
2389         // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
2390         // undesirable.
2391         break;
2392 
2393       if (Src.getNode()->hasOneUse()) {
2394         const APInt *ShAmtC =
2395             TLO.DAG.getValidShiftAmountConstant(Src, DemandedElts);
2396         if (!ShAmtC || ShAmtC->uge(BitWidth))
2397           break;
2398         uint64_t ShVal = ShAmtC->getZExtValue();
2399 
2400         APInt HighBits =
2401             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
2402         HighBits.lshrInPlace(ShVal);
2403         HighBits = HighBits.trunc(BitWidth);
2404 
2405         if (!(HighBits & DemandedBits)) {
2406           // None of the shifted in bits are needed.  Add a truncate of the
2407           // shift input, then shift it.
2408           SDValue NewShAmt = TLO.DAG.getConstant(
2409               ShVal, dl, getShiftAmountTy(VT, DL, TLO.LegalTypes()));
2410           SDValue NewTrunc =
2411               TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
2412           return TLO.CombineTo(
2413               Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, NewShAmt));
2414         }
2415       }
2416       break;
2417     }
2418 
2419     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
2420     break;
2421   }
2422   case ISD::AssertZext: {
2423     // AssertZext demands all of the high bits, plus any of the low bits
2424     // demanded by its users.
2425     EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2426     APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits());
2427     if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known,
2428                              TLO, Depth + 1))
2429       return true;
2430     assert(!Known.hasConflict() && "Bits known to be one AND zero?");
2431 
2432     Known.Zero |= ~InMask;
2433     Known.One &= (~Known.Zero);
2434     break;
2435   }
2436   case ISD::EXTRACT_VECTOR_ELT: {
2437     SDValue Src = Op.getOperand(0);
2438     SDValue Idx = Op.getOperand(1);
2439     ElementCount SrcEltCnt = Src.getValueType().getVectorElementCount();
2440     unsigned EltBitWidth = Src.getScalarValueSizeInBits();
2441 
2442     if (SrcEltCnt.isScalable())
2443       return false;
2444 
2445     // Demand the bits from every vector element without a constant index.
2446     unsigned NumSrcElts = SrcEltCnt.getFixedValue();
2447     APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
2448     if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx))
2449       if (CIdx->getAPIntValue().ult(NumSrcElts))
2450         DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue());
2451 
2452     // If BitWidth > EltBitWidth the value is anyext:ed. So we do not know
2453     // anything about the extended bits.
2454     APInt DemandedSrcBits = DemandedBits;
2455     if (BitWidth > EltBitWidth)
2456       DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth);
2457 
2458     if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO,
2459                              Depth + 1))
2460       return true;
2461 
2462     // Attempt to avoid multi-use ops if we don't need anything from them.
2463     if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2464       if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2465               Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2466         SDValue NewOp =
2467             TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc, Idx);
2468         return TLO.CombineTo(Op, NewOp);
2469       }
2470     }
2471 
2472     Known = Known2;
2473     if (BitWidth > EltBitWidth)
2474       Known = Known.anyext(BitWidth);
2475     break;
2476   }
2477   case ISD::BITCAST: {
2478     if (VT.isScalableVector())
2479       return false;
2480     SDValue Src = Op.getOperand(0);
2481     EVT SrcVT = Src.getValueType();
2482     unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
2483 
2484     // If this is an FP->Int bitcast and if the sign bit is the only
2485     // thing demanded, turn this into a FGETSIGN.
2486     if (!TLO.LegalOperations() && !VT.isVector() && !SrcVT.isVector() &&
2487         DemandedBits == APInt::getSignMask(Op.getValueSizeInBits()) &&
2488         SrcVT.isFloatingPoint()) {
2489       bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, VT);
2490       bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32);
2491       if ((OpVTLegal || i32Legal) && VT.isSimple() && SrcVT != MVT::f16 &&
2492           SrcVT != MVT::f128) {
2493         // Cannot eliminate/lower SHL for f128 yet.
2494         EVT Ty = OpVTLegal ? VT : MVT::i32;
2495         // Make a FGETSIGN + SHL to move the sign bit into the appropriate
2496         // place.  We expect the SHL to be eliminated by other optimizations.
2497         SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Src);
2498         unsigned OpVTSizeInBits = Op.getValueSizeInBits();
2499         if (!OpVTLegal && OpVTSizeInBits > 32)
2500           Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Sign);
2501         unsigned ShVal = Op.getValueSizeInBits() - 1;
2502         SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, VT);
2503         return TLO.CombineTo(Op,
2504                              TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
2505       }
2506     }
2507 
2508     // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
2509     // Demand the elt/bit if any of the original elts/bits are demanded.
2510     if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
2511       unsigned Scale = BitWidth / NumSrcEltBits;
2512       unsigned NumSrcElts = SrcVT.getVectorNumElements();
2513       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2514       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2515       for (unsigned i = 0; i != Scale; ++i) {
2516         unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
2517         unsigned BitOffset = EltOffset * NumSrcEltBits;
2518         APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
2519         if (!Sub.isZero()) {
2520           DemandedSrcBits |= Sub;
2521           for (unsigned j = 0; j != NumElts; ++j)
2522             if (DemandedElts[j])
2523               DemandedSrcElts.setBit((j * Scale) + i);
2524         }
2525       }
2526 
2527       APInt KnownSrcUndef, KnownSrcZero;
2528       if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2529                                      KnownSrcZero, TLO, Depth + 1))
2530         return true;
2531 
2532       KnownBits KnownSrcBits;
2533       if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2534                                KnownSrcBits, TLO, Depth + 1))
2535         return true;
2536     } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
2537       // TODO - bigendian once we have test coverage.
2538       unsigned Scale = NumSrcEltBits / BitWidth;
2539       unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
2540       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2541       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2542       for (unsigned i = 0; i != NumElts; ++i)
2543         if (DemandedElts[i]) {
2544           unsigned Offset = (i % Scale) * BitWidth;
2545           DemandedSrcBits.insertBits(DemandedBits, Offset);
2546           DemandedSrcElts.setBit(i / Scale);
2547         }
2548 
2549       if (SrcVT.isVector()) {
2550         APInt KnownSrcUndef, KnownSrcZero;
2551         if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2552                                        KnownSrcZero, TLO, Depth + 1))
2553           return true;
2554       }
2555 
2556       KnownBits KnownSrcBits;
2557       if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2558                                KnownSrcBits, TLO, Depth + 1))
2559         return true;
2560     }
2561 
2562     // If this is a bitcast, let computeKnownBits handle it.  Only do this on a
2563     // recursive call where Known may be useful to the caller.
2564     if (Depth > 0) {
2565       Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2566       return false;
2567     }
2568     break;
2569   }
2570   case ISD::MUL:
2571     if (DemandedBits.isPowerOf2()) {
2572       // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
2573       // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
2574       // odd (has LSB set), then the left-shifted low bit of X is the answer.
2575       unsigned CTZ = DemandedBits.countTrailingZeros();
2576       ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
2577       if (C && C->getAPIntValue().countTrailingZeros() == CTZ) {
2578         EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
2579         SDValue AmtC = TLO.DAG.getConstant(CTZ, dl, ShiftAmtTy);
2580         SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
2581         return TLO.CombineTo(Op, Shl);
2582       }
2583     }
2584     // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
2585     // X * X is odd iff X is odd.
2586     // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
2587     if (Op.getOperand(0) == Op.getOperand(1) && DemandedBits.ult(4)) {
2588       SDValue One = TLO.DAG.getConstant(1, dl, VT);
2589       SDValue And1 = TLO.DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), One);
2590       return TLO.CombineTo(Op, And1);
2591     }
2592     [[fallthrough]];
2593   case ISD::ADD:
2594   case ISD::SUB: {
2595     // Add, Sub, and Mul don't demand any bits in positions beyond that
2596     // of the highest bit demanded of them.
2597     SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1);
2598     SDNodeFlags Flags = Op.getNode()->getFlags();
2599     unsigned DemandedBitsLZ = DemandedBits.countLeadingZeros();
2600     APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
2601     if (SimplifyDemandedBits(Op0, LoMask, DemandedElts, Known2, TLO,
2602                              Depth + 1) ||
2603         SimplifyDemandedBits(Op1, LoMask, DemandedElts, Known2, TLO,
2604                              Depth + 1) ||
2605         // See if the operation should be performed at a smaller bit width.
2606         ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) {
2607       if (Flags.hasNoSignedWrap() || Flags.hasNoUnsignedWrap()) {
2608         // Disable the nsw and nuw flags. We can no longer guarantee that we
2609         // won't wrap after simplification.
2610         Flags.setNoSignedWrap(false);
2611         Flags.setNoUnsignedWrap(false);
2612         Op->setFlags(Flags);
2613       }
2614       return true;
2615     }
2616 
2617     // neg x with only low bit demanded is simply x.
2618     if (Op.getOpcode() == ISD::SUB && DemandedBits.isOne() &&
2619         isa<ConstantSDNode>(Op0) && cast<ConstantSDNode>(Op0)->isZero())
2620       return TLO.CombineTo(Op, Op1);
2621 
2622     // Attempt to avoid multi-use ops if we don't need anything from them.
2623     if (!LoMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2624       SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2625           Op0, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2626       SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2627           Op1, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2628       if (DemandedOp0 || DemandedOp1) {
2629         Flags.setNoSignedWrap(false);
2630         Flags.setNoUnsignedWrap(false);
2631         Op0 = DemandedOp0 ? DemandedOp0 : Op0;
2632         Op1 = DemandedOp1 ? DemandedOp1 : Op1;
2633         SDValue NewOp =
2634             TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1, Flags);
2635         return TLO.CombineTo(Op, NewOp);
2636       }
2637     }
2638 
2639     // If we have a constant operand, we may be able to turn it into -1 if we
2640     // do not demand the high bits. This can make the constant smaller to
2641     // encode, allow more general folding, or match specialized instruction
2642     // patterns (eg, 'blsr' on x86). Don't bother changing 1 to -1 because that
2643     // is probably not useful (and could be detrimental).
2644     ConstantSDNode *C = isConstOrConstSplat(Op1);
2645     APInt HighMask = APInt::getHighBitsSet(BitWidth, DemandedBitsLZ);
2646     if (C && !C->isAllOnes() && !C->isOne() &&
2647         (C->getAPIntValue() | HighMask).isAllOnes()) {
2648       SDValue Neg1 = TLO.DAG.getAllOnesConstant(dl, VT);
2649       // Disable the nsw and nuw flags. We can no longer guarantee that we
2650       // won't wrap after simplification.
2651       Flags.setNoSignedWrap(false);
2652       Flags.setNoUnsignedWrap(false);
2653       SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Neg1, Flags);
2654       return TLO.CombineTo(Op, NewOp);
2655     }
2656 
2657     // Match a multiply with a disguised negated-power-of-2 and convert to a
2658     // an equivalent shift-left amount.
2659     // Example: (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2660     auto getShiftLeftAmt = [&HighMask](SDValue Mul) -> unsigned {
2661       if (Mul.getOpcode() != ISD::MUL || !Mul.hasOneUse())
2662         return 0;
2663 
2664       // Don't touch opaque constants. Also, ignore zero and power-of-2
2665       // multiplies. Those will get folded later.
2666       ConstantSDNode *MulC = isConstOrConstSplat(Mul.getOperand(1));
2667       if (MulC && !MulC->isOpaque() && !MulC->isZero() &&
2668           !MulC->getAPIntValue().isPowerOf2()) {
2669         APInt UnmaskedC = MulC->getAPIntValue() | HighMask;
2670         if (UnmaskedC.isNegatedPowerOf2())
2671           return (-UnmaskedC).logBase2();
2672       }
2673       return 0;
2674     };
2675 
2676     auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y, unsigned ShlAmt) {
2677       EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
2678       SDValue ShlAmtC = TLO.DAG.getConstant(ShlAmt, dl, ShiftAmtTy);
2679       SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
2680       SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
2681       return TLO.CombineTo(Op, Res);
2682     };
2683 
2684     if (isOperationLegalOrCustom(ISD::SHL, VT)) {
2685       if (Op.getOpcode() == ISD::ADD) {
2686         // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2687         if (unsigned ShAmt = getShiftLeftAmt(Op0))
2688           return foldMul(ISD::SUB, Op0.getOperand(0), Op1, ShAmt);
2689         // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
2690         if (unsigned ShAmt = getShiftLeftAmt(Op1))
2691           return foldMul(ISD::SUB, Op1.getOperand(0), Op0, ShAmt);
2692       }
2693       if (Op.getOpcode() == ISD::SUB) {
2694         // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
2695         if (unsigned ShAmt = getShiftLeftAmt(Op1))
2696           return foldMul(ISD::ADD, Op1.getOperand(0), Op0, ShAmt);
2697       }
2698     }
2699 
2700     [[fallthrough]];
2701   }
2702   default:
2703     // We also ask the target about intrinsics (which could be specific to it).
2704     if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
2705         Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
2706       // TODO: Probably okay to remove after audit; here to reduce change size
2707       // in initial enablement patch for scalable vectors
2708       if (Op.getValueType().isScalableVector())
2709         break;
2710       if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
2711                                             Known, TLO, Depth))
2712         return true;
2713       break;
2714     }
2715 
2716     // Just use computeKnownBits to compute output bits.
2717     Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2718     break;
2719   }
2720 
2721   // If we know the value of all of the demanded bits, return this as a
2722   // constant.
2723   if (!isTargetCanonicalConstantNode(Op) &&
2724       DemandedBits.isSubsetOf(Known.Zero | Known.One)) {
2725     // Avoid folding to a constant if any OpaqueConstant is involved.
2726     const SDNode *N = Op.getNode();
2727     for (SDNode *Op :
2728          llvm::make_range(SDNodeIterator::begin(N), SDNodeIterator::end(N))) {
2729       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op))
2730         if (C->isOpaque())
2731           return false;
2732     }
2733     if (VT.isInteger())
2734       return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT));
2735     if (VT.isFloatingPoint())
2736       return TLO.CombineTo(
2737           Op,
2738           TLO.DAG.getConstantFP(
2739               APFloat(TLO.DAG.EVTToAPFloatSemantics(VT), Known.One), dl, VT));
2740   }
2741 
2742   // A multi use 'all demanded elts' simplify failed to find any knownbits.
2743   // Try again just for the original demanded elts.
2744   // Ensure we do this AFTER constant folding above.
2745   if (HasMultiUse && Known.isUnknown() && !OriginalDemandedElts.isAllOnes())
2746     Known = TLO.DAG.computeKnownBits(Op, OriginalDemandedElts, Depth);
2747 
2748   return false;
2749 }
2750 
SimplifyDemandedVectorElts(SDValue Op,const APInt & DemandedElts,DAGCombinerInfo & DCI) const2751 bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
2752                                                 const APInt &DemandedElts,
2753                                                 DAGCombinerInfo &DCI) const {
2754   SelectionDAG &DAG = DCI.DAG;
2755   TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
2756                         !DCI.isBeforeLegalizeOps());
2757 
2758   APInt KnownUndef, KnownZero;
2759   bool Simplified =
2760       SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
2761   if (Simplified) {
2762     DCI.AddToWorklist(Op.getNode());
2763     DCI.CommitTargetLoweringOpt(TLO);
2764   }
2765 
2766   return Simplified;
2767 }
2768 
2769 /// Given a vector binary operation and known undefined elements for each input
2770 /// operand, compute whether each element of the output is undefined.
getKnownUndefForVectorBinop(SDValue BO,SelectionDAG & DAG,const APInt & UndefOp0,const APInt & UndefOp1)2771 static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG,
2772                                          const APInt &UndefOp0,
2773                                          const APInt &UndefOp1) {
2774   EVT VT = BO.getValueType();
2775   assert(DAG.getTargetLoweringInfo().isBinOp(BO.getOpcode()) && VT.isVector() &&
2776          "Vector binop only");
2777 
2778   EVT EltVT = VT.getVectorElementType();
2779   unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
2780   assert(UndefOp0.getBitWidth() == NumElts &&
2781          UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
2782 
2783   auto getUndefOrConstantElt = [&](SDValue V, unsigned Index,
2784                                    const APInt &UndefVals) {
2785     if (UndefVals[Index])
2786       return DAG.getUNDEF(EltVT);
2787 
2788     if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) {
2789       // Try hard to make sure that the getNode() call is not creating temporary
2790       // nodes. Ignore opaque integers because they do not constant fold.
2791       SDValue Elt = BV->getOperand(Index);
2792       auto *C = dyn_cast<ConstantSDNode>(Elt);
2793       if (isa<ConstantFPSDNode>(Elt) || Elt.isUndef() || (C && !C->isOpaque()))
2794         return Elt;
2795     }
2796 
2797     return SDValue();
2798   };
2799 
2800   APInt KnownUndef = APInt::getZero(NumElts);
2801   for (unsigned i = 0; i != NumElts; ++i) {
2802     // If both inputs for this element are either constant or undef and match
2803     // the element type, compute the constant/undef result for this element of
2804     // the vector.
2805     // TODO: Ideally we would use FoldConstantArithmetic() here, but that does
2806     // not handle FP constants. The code within getNode() should be refactored
2807     // to avoid the danger of creating a bogus temporary node here.
2808     SDValue C0 = getUndefOrConstantElt(BO.getOperand(0), i, UndefOp0);
2809     SDValue C1 = getUndefOrConstantElt(BO.getOperand(1), i, UndefOp1);
2810     if (C0 && C1 && C0.getValueType() == EltVT && C1.getValueType() == EltVT)
2811       if (DAG.getNode(BO.getOpcode(), SDLoc(BO), EltVT, C0, C1).isUndef())
2812         KnownUndef.setBit(i);
2813   }
2814   return KnownUndef;
2815 }
2816 
/// Recursively attempt to simplify \p Op, given that only the vector elements
/// in \p OriginalDemandedElts are used by Op's users. On return, \p KnownUndef
/// and \p KnownZero record which result elements are known to be undef or
/// zero. Returns true if a simplification was committed through \p TLO.
bool TargetLowering::SimplifyDemandedVectorElts(
    SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
    APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
    bool AssumeSingleUse) const {
  EVT VT = Op.getValueType();
  unsigned Opcode = Op.getOpcode();
  APInt DemandedElts = OriginalDemandedElts;
  unsigned NumElts = DemandedElts.getBitWidth();
  assert(VT.isVector() && "Expected vector op");

  KnownUndef = KnownZero = APInt::getZero(NumElts);

  const TargetLowering &TLI = TLO.DAG.getTargetLoweringInfo();
  if (!TLI.shouldSimplifyDemandedVectorElts(Op, TLO))
    return false;

  // TODO: For now we assume we know nothing about scalable vectors.
  if (VT.isScalableVector())
    return false;

  assert(VT.getVectorNumElements() == NumElts &&
         "Mask size mismatches value type element count!");

  // Undef operand.
  if (Op.isUndef()) {
    KnownUndef.setAllBits();
    return false;
  }

  // If Op has other users, assume that all elements are needed.
  if (!AssumeSingleUse && !Op.getNode()->hasOneUse())
    DemandedElts.setAllBits();

  // Not demanding any elements from Op.
  if (DemandedElts == 0) {
    KnownUndef.setAllBits();
    return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
  }

  // Limit search depth.
  if (Depth >= SelectionDAG::MaxRecursionDepth)
    return false;

  SDLoc DL(Op);
  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();

  // Helper for demanding the specified elements and all the bits of both binary
  // operands.
  auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
    SDValue NewOp0 = SimplifyMultipleUseDemandedVectorElts(Op0, DemandedElts,
                                                           TLO.DAG, Depth + 1);
    SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts,
                                                           TLO.DAG, Depth + 1);
    if (NewOp0 || NewOp1) {
      SDValue NewOp = TLO.DAG.getNode(
          Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0, NewOp1 ? NewOp1 : Op1);
      return TLO.CombineTo(Op, NewOp);
    }
    return false;
  };

  switch (Opcode) {
  case ISD::SCALAR_TO_VECTOR: {
    if (!DemandedElts[0]) {
      KnownUndef.setAllBits();
      return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
    }
    SDValue ScalarSrc = Op.getOperand(0);
    if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
      SDValue Src = ScalarSrc.getOperand(0);
      SDValue Idx = ScalarSrc.getOperand(1);
      EVT SrcVT = Src.getValueType();

      ElementCount SrcEltCnt = SrcVT.getVectorElementCount();

      if (SrcEltCnt.isScalable())
        return false;

      unsigned NumSrcElts = SrcEltCnt.getFixedValue();
      if (isNullConstant(Idx)) {
        APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0);
        APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts);
        APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts);
        if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                       TLO, Depth + 1))
          return true;
      }
    }
    // Only element 0 is written by SCALAR_TO_VECTOR; everything above it is
    // undefined.
    KnownUndef.setHighBits(NumElts - 1);
    break;
  }
  case ISD::BITCAST: {
    SDValue Src = Op.getOperand(0);
    EVT SrcVT = Src.getValueType();

    // We only handle vectors here.
    // TODO - investigate calling SimplifyDemandedBits/ComputeKnownBits?
    if (!SrcVT.isVector())
      break;

    // Fast handling of 'identity' bitcasts.
    unsigned NumSrcElts = SrcVT.getVectorNumElements();
    if (NumSrcElts == NumElts)
      return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef,
                                        KnownZero, TLO, Depth + 1);

    APInt SrcDemandedElts, SrcZero, SrcUndef;

    // Bitcast from 'large element' src vector to 'small element' vector, we
    // must demand a source element if any DemandedElt maps to it.
    if ((NumElts % NumSrcElts) == 0) {
      unsigned Scale = NumElts / NumSrcElts;
      SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
      if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                     TLO, Depth + 1))
        return true;

      // Try calling SimplifyDemandedBits, converting demanded elts to the bits
      // of the large element.
      // TODO - bigendian once we have test coverage.
      if (IsLE) {
        unsigned SrcEltSizeInBits = SrcVT.getScalarSizeInBits();
        APInt SrcDemandedBits = APInt::getZero(SrcEltSizeInBits);
        for (unsigned i = 0; i != NumElts; ++i)
          if (DemandedElts[i]) {
            unsigned Ofs = (i % Scale) * EltSizeInBits;
            SrcDemandedBits.setBits(Ofs, Ofs + EltSizeInBits);
          }

        KnownBits Known;
        if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcDemandedElts, Known,
                                 TLO, Depth + 1))
          return true;

        // The bitcast has split each wide element into a number of
        // narrow subelements. We have just computed the Known bits
        // for wide elements. See if element splitting results in
        // some subelements being zero. Only for demanded elements!
        for (unsigned SubElt = 0; SubElt != Scale; ++SubElt) {
          if (!Known.Zero.extractBits(EltSizeInBits, SubElt * EltSizeInBits)
                   .isAllOnes())
            continue;
          for (unsigned SrcElt = 0; SrcElt != NumSrcElts; ++SrcElt) {
            unsigned Elt = Scale * SrcElt + SubElt;
            if (DemandedElts[Elt])
              KnownZero.setBit(Elt);
          }
        }
      }

      // If the src element is zero/undef then all the output elements will be -
      // only demanded elements are guaranteed to be correct.
      for (unsigned i = 0; i != NumSrcElts; ++i) {
        if (SrcDemandedElts[i]) {
          if (SrcZero[i])
            KnownZero.setBits(i * Scale, (i + 1) * Scale);
          if (SrcUndef[i])
            KnownUndef.setBits(i * Scale, (i + 1) * Scale);
        }
      }
    }

    // Bitcast from 'small element' src vector to 'large element' vector, we
    // demand all smaller source elements covered by the larger demanded element
    // of this vector.
    if ((NumSrcElts % NumElts) == 0) {
      unsigned Scale = NumSrcElts / NumElts;
      SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
      if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                     TLO, Depth + 1))
        return true;

      // If all the src elements covering an output element are zero/undef, then
      // the output element will be as well, assuming it was demanded.
      for (unsigned i = 0; i != NumElts; ++i) {
        if (DemandedElts[i]) {
          if (SrcZero.extractBits(Scale, i * Scale).isAllOnes())
            KnownZero.setBit(i);
          if (SrcUndef.extractBits(Scale, i * Scale).isAllOnes())
            KnownUndef.setBit(i);
        }
      }
    }
    break;
  }
  case ISD::BUILD_VECTOR: {
    // Check all elements and simplify any unused elements with UNDEF.
    if (!DemandedElts.isAllOnes()) {
      // Don't simplify BROADCASTS.
      if (llvm::any_of(Op->op_values(),
                       [&](SDValue Elt) { return Op.getOperand(0) != Elt; })) {
        SmallVector<SDValue, 32> Ops(Op->op_begin(), Op->op_end());
        bool Updated = false;
        for (unsigned i = 0; i != NumElts; ++i) {
          if (!DemandedElts[i] && !Ops[i].isUndef()) {
            Ops[i] = TLO.DAG.getUNDEF(Ops[0].getValueType());
            KnownUndef.setBit(i);
            Updated = true;
          }
        }
        if (Updated)
          return TLO.CombineTo(Op, TLO.DAG.getBuildVector(VT, DL, Ops));
      }
    }
    // Record which operands are already undef or (type-matching) zero.
    for (unsigned i = 0; i != NumElts; ++i) {
      SDValue SrcOp = Op.getOperand(i);
      if (SrcOp.isUndef()) {
        KnownUndef.setBit(i);
      } else if (EltSizeInBits == SrcOp.getScalarValueSizeInBits() &&
                 (isNullConstant(SrcOp) || isNullFPConstant(SrcOp))) {
        KnownZero.setBit(i);
      }
    }
    break;
  }
  case ISD::CONCAT_VECTORS: {
    EVT SubVT = Op.getOperand(0).getValueType();
    unsigned NumSubVecs = Op.getNumOperands();
    unsigned NumSubElts = SubVT.getVectorNumElements();
    for (unsigned i = 0; i != NumSubVecs; ++i) {
      SDValue SubOp = Op.getOperand(i);
      APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
      APInt SubUndef, SubZero;
      if (SimplifyDemandedVectorElts(SubOp, SubElts, SubUndef, SubZero, TLO,
                                     Depth + 1))
        return true;
      KnownUndef.insertBits(SubUndef, i * NumSubElts);
      KnownZero.insertBits(SubZero, i * NumSubElts);
    }

    // Attempt to avoid multi-use ops if we don't need anything from them.
    if (!DemandedElts.isAllOnes()) {
      bool FoundNewSub = false;
      SmallVector<SDValue, 2> DemandedSubOps;
      for (unsigned i = 0; i != NumSubVecs; ++i) {
        SDValue SubOp = Op.getOperand(i);
        APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
        SDValue NewSubOp = SimplifyMultipleUseDemandedVectorElts(
            SubOp, SubElts, TLO.DAG, Depth + 1);
        DemandedSubOps.push_back(NewSubOp ? NewSubOp : SubOp);
        FoundNewSub = NewSubOp ? true : FoundNewSub;
      }
      if (FoundNewSub) {
        SDValue NewOp =
            TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, DemandedSubOps);
        return TLO.CombineTo(Op, NewOp);
      }
    }
    break;
  }
  case ISD::INSERT_SUBVECTOR: {
    // Demand any elements from the subvector and the remainder from the src its
    // inserted into.
    SDValue Src = Op.getOperand(0);
    SDValue Sub = Op.getOperand(1);
    uint64_t Idx = Op.getConstantOperandVal(2);
    unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
    APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
    APInt DemandedSrcElts = DemandedElts;
    DemandedSrcElts.insertBits(APInt::getZero(NumSubElts), Idx);

    APInt SubUndef, SubZero;
    if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
                                   Depth + 1))
      return true;

    // If none of the src operand elements are demanded, replace it with undef.
    if (!DemandedSrcElts && !Src.isUndef())
      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
                                               TLO.DAG.getUNDEF(VT), Sub,
                                               Op.getOperand(2)));

    if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownUndef, KnownZero,
                                   TLO, Depth + 1))
      return true;
    KnownUndef.insertBits(SubUndef, Idx);
    KnownZero.insertBits(SubZero, Idx);

    // Attempt to avoid multi-use ops if we don't need anything from them.
    if (!DemandedSrcElts.isAllOnes() || !DemandedSubElts.isAllOnes()) {
      SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
          Src, DemandedSrcElts, TLO.DAG, Depth + 1);
      SDValue NewSub = SimplifyMultipleUseDemandedVectorElts(
          Sub, DemandedSubElts, TLO.DAG, Depth + 1);
      if (NewSrc || NewSub) {
        NewSrc = NewSrc ? NewSrc : Src;
        NewSub = NewSub ? NewSub : Sub;
        SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
                                        NewSub, Op.getOperand(2));
        return TLO.CombineTo(Op, NewOp);
      }
    }
    break;
  }
  case ISD::EXTRACT_SUBVECTOR: {
    // Offset the demanded elts by the subvector index.
    SDValue Src = Op.getOperand(0);
    if (Src.getValueType().isScalableVector())
      break;
    uint64_t Idx = Op.getConstantOperandVal(1);
    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);

    APInt SrcUndef, SrcZero;
    if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
                                   Depth + 1))
      return true;
    KnownUndef = SrcUndef.extractBits(NumElts, Idx);
    KnownZero = SrcZero.extractBits(NumElts, Idx);

    // Attempt to avoid multi-use ops if we don't need anything from them.
    if (!DemandedElts.isAllOnes()) {
      SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
          Src, DemandedSrcElts, TLO.DAG, Depth + 1);
      if (NewSrc) {
        SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
                                        Op.getOperand(1));
        return TLO.CombineTo(Op, NewOp);
      }
    }
    break;
  }
  case ISD::INSERT_VECTOR_ELT: {
    SDValue Vec = Op.getOperand(0);
    SDValue Scl = Op.getOperand(1);
    auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));

    // For a legal, constant insertion index, if we don't need this insertion
    // then strip it, else remove it from the demanded elts.
    if (CIdx && CIdx->getAPIntValue().ult(NumElts)) {
      unsigned Idx = CIdx->getZExtValue();
      if (!DemandedElts[Idx])
        return TLO.CombineTo(Op, Vec);

      APInt DemandedVecElts(DemandedElts);
      DemandedVecElts.clearBit(Idx);
      if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef,
                                     KnownZero, TLO, Depth + 1))
        return true;

      KnownUndef.setBitVal(Idx, Scl.isUndef());

      KnownZero.setBitVal(Idx, isNullConstant(Scl) || isNullFPConstant(Scl));
      break;
    }

    // Non-constant (or out-of-range) insertion index: conservatively demand
    // all elements of the source vector.
    APInt VecUndef, VecZero;
    if (SimplifyDemandedVectorElts(Vec, DemandedElts, VecUndef, VecZero, TLO,
                                   Depth + 1))
      return true;
    // Without knowing the insertion index we can't set KnownUndef/KnownZero.
    break;
  }
  case ISD::VSELECT: {
    SDValue Sel = Op.getOperand(0);
    SDValue LHS = Op.getOperand(1);
    SDValue RHS = Op.getOperand(2);

    // Try to transform the select condition based on the current demanded
    // elements.
    APInt UndefSel, UndefZero;
    if (SimplifyDemandedVectorElts(Sel, DemandedElts, UndefSel, UndefZero, TLO,
                                   Depth + 1))
      return true;

    // See if we can simplify either vselect operand.
    APInt DemandedLHS(DemandedElts);
    APInt DemandedRHS(DemandedElts);
    APInt UndefLHS, ZeroLHS;
    APInt UndefRHS, ZeroRHS;
    if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
                                   Depth + 1))
      return true;
    if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
                                   Depth + 1))
      return true;

    KnownUndef = UndefLHS & UndefRHS;
    KnownZero = ZeroLHS & ZeroRHS;

    // If we know that the selected element is always zero, we don't need the
    // select value element.
    APInt DemandedSel = DemandedElts & ~KnownZero;
    if (DemandedSel != DemandedElts)
      if (SimplifyDemandedVectorElts(Sel, DemandedSel, UndefSel, UndefZero, TLO,
                                     Depth + 1))
        return true;

    break;
  }
  case ISD::VECTOR_SHUFFLE: {
    SDValue LHS = Op.getOperand(0);
    SDValue RHS = Op.getOperand(1);
    ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();

    // Collect demanded elements from shuffle operands..
    APInt DemandedLHS(NumElts, 0);
    APInt DemandedRHS(NumElts, 0);
    for (unsigned i = 0; i != NumElts; ++i) {
      int M = ShuffleMask[i];
      if (M < 0 || !DemandedElts[i])
        continue;
      assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
      if (M < (int)NumElts)
        DemandedLHS.setBit(M);
      else
        DemandedRHS.setBit(M - NumElts);
    }

    // See if we can simplify either shuffle operand.
    APInt UndefLHS, ZeroLHS;
    APInt UndefRHS, ZeroRHS;
    if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
                                   Depth + 1))
      return true;
    if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
                                   Depth + 1))
      return true;

    // Simplify mask using undef elements from LHS/RHS.
    bool Updated = false;
    bool IdentityLHS = true, IdentityRHS = true;
    SmallVector<int, 32> NewMask(ShuffleMask);
    for (unsigned i = 0; i != NumElts; ++i) {
      int &M = NewMask[i];
      if (M < 0)
        continue;
      if (!DemandedElts[i] || (M < (int)NumElts && UndefLHS[M]) ||
          (M >= (int)NumElts && UndefRHS[M - NumElts])) {
        Updated = true;
        M = -1;
      }
      IdentityLHS &= (M < 0) || (M == (int)i);
      IdentityRHS &= (M < 0) || ((M - NumElts) == i);
    }

    // Update legal shuffle masks based on demanded elements if it won't reduce
    // to Identity which can cause premature removal of the shuffle mask.
    if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps) {
      SDValue LegalShuffle =
          buildLegalVectorShuffle(VT, DL, LHS, RHS, NewMask, TLO.DAG);
      if (LegalShuffle)
        return TLO.CombineTo(Op, LegalShuffle);
    }

    // Propagate undef/zero elements from LHS/RHS.
    for (unsigned i = 0; i != NumElts; ++i) {
      int M = ShuffleMask[i];
      if (M < 0) {
        KnownUndef.setBit(i);
      } else if (M < (int)NumElts) {
        if (UndefLHS[M])
          KnownUndef.setBit(i);
        if (ZeroLHS[M])
          KnownZero.setBit(i);
      } else {
        if (UndefRHS[M - NumElts])
          KnownUndef.setBit(i);
        if (ZeroRHS[M - NumElts])
          KnownZero.setBit(i);
      }
    }
    break;
  }
  case ISD::ANY_EXTEND_VECTOR_INREG:
  case ISD::SIGN_EXTEND_VECTOR_INREG:
  case ISD::ZERO_EXTEND_VECTOR_INREG: {
    APInt SrcUndef, SrcZero;
    SDValue Src = Op.getOperand(0);
    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
    APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
    if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
                                   Depth + 1))
      return true;
    KnownZero = SrcZero.zextOrTrunc(NumElts);
    KnownUndef = SrcUndef.zextOrTrunc(NumElts);

    if (IsLE && Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG &&
        Op.getValueSizeInBits() == Src.getValueSizeInBits() &&
        DemandedSrcElts == 1) {
      // aext - if we just need the bottom element then we can bitcast.
      return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
    }

    if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
      // zext(undef) upper bits are guaranteed to be zero.
      if (DemandedElts.isSubsetOf(KnownUndef))
        return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
      KnownUndef.clearAllBits();

      // zext - if we just need the bottom element then we can mask:
      // zext(and(x,c)) -> and(x,c') iff the zext is the only user of the and.
      if (IsLE && DemandedSrcElts == 1 && Src.getOpcode() == ISD::AND &&
          Op->isOnlyUserOf(Src.getNode()) &&
          Op.getValueSizeInBits() == Src.getValueSizeInBits()) {
        SDLoc DL(Op);
        EVT SrcVT = Src.getValueType();
        EVT SrcSVT = SrcVT.getScalarType();
        SmallVector<SDValue> MaskElts;
        MaskElts.push_back(TLO.DAG.getAllOnesConstant(DL, SrcSVT));
        MaskElts.append(NumSrcElts - 1, TLO.DAG.getConstant(0, DL, SrcSVT));
        SDValue Mask = TLO.DAG.getBuildVector(SrcVT, DL, MaskElts);
        if (SDValue Fold = TLO.DAG.FoldConstantArithmetic(
                ISD::AND, DL, SrcVT, {Src.getOperand(1), Mask})) {
          Fold = TLO.DAG.getNode(ISD::AND, DL, SrcVT, Src.getOperand(0), Fold);
          return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Fold));
        }
      }
    }
    break;
  }

  // TODO: There are more binop opcodes that could be handled here - MIN,
  // MAX, saturated math, etc.
  case ISD::ADD: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);
    if (Op0 == Op1 && Op->isOnlyUserOf(Op0.getNode())) {
      APInt UndefLHS, ZeroLHS;
      if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
                                     Depth + 1, /*AssumeSingleUse*/ true))
        return true;
    }
    [[fallthrough]];
  }
  case ISD::OR:
  case ISD::XOR:
  case ISD::SUB:
  case ISD::FADD:
  case ISD::FSUB:
  case ISD::FMUL:
  case ISD::FDIV:
  case ISD::FREM: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);

    APInt UndefRHS, ZeroRHS;
    if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
                                   Depth + 1))
      return true;
    APInt UndefLHS, ZeroLHS;
    if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
                                   Depth + 1))
      return true;

    KnownZero = ZeroLHS & ZeroRHS;
    KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS);

    // Attempt to avoid multi-use ops if we don't need anything from them.
    // TODO - use KnownUndef to relax the demandedelts?
    if (!DemandedElts.isAllOnes())
      if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
        return true;
    break;
  }
  case ISD::SHL:
  case ISD::SRL:
  case ISD::SRA:
  case ISD::ROTL:
  case ISD::ROTR: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);

    APInt UndefRHS, ZeroRHS;
    if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
                                   Depth + 1))
      return true;
    APInt UndefLHS, ZeroLHS;
    if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
                                   Depth + 1))
      return true;

    // A zero element of the shifted operand stays zero after the shift/rotate.
    KnownZero = ZeroLHS;
    KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop?

    // Attempt to avoid multi-use ops if we don't need anything from them.
    // TODO - use KnownUndef to relax the demandedelts?
    if (!DemandedElts.isAllOnes())
      if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
        return true;
    break;
  }
  case ISD::MUL:
  case ISD::MULHU:
  case ISD::MULHS:
  case ISD::AND: {
    SDValue Op0 = Op.getOperand(0);
    SDValue Op1 = Op.getOperand(1);

    APInt SrcUndef, SrcZero;
    if (SimplifyDemandedVectorElts(Op1, DemandedElts, SrcUndef, SrcZero, TLO,
                                   Depth + 1))
      return true;
    // If we know that a demanded element was zero in Op1 we don't need to
    // demand it in Op0 - its guaranteed to be zero.
    APInt DemandedElts0 = DemandedElts & ~SrcZero;
    if (SimplifyDemandedVectorElts(Op0, DemandedElts0, KnownUndef, KnownZero,
                                   TLO, Depth + 1))
      return true;

    KnownUndef &= DemandedElts0;
    KnownZero &= DemandedElts0;

    // If every element pair has a zero/undef then just fold to zero.
    // fold (and x, undef) -> 0  /  (and x, 0) -> 0
    // fold (mul x, undef) -> 0  /  (mul x, 0) -> 0
    if (DemandedElts.isSubsetOf(SrcZero | KnownZero | SrcUndef | KnownUndef))
      return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));

    // If either side has a zero element, then the result element is zero, even
    // if the other is an UNDEF.
    // TODO: Extend getKnownUndefForVectorBinop to also deal with known zeros
    // and then handle 'and' nodes with the rest of the binop opcodes.
    KnownZero |= SrcZero;
    KnownUndef &= SrcUndef;
    KnownUndef &= ~KnownZero;

    // Attempt to avoid multi-use ops if we don't need anything from them.
    if (!DemandedElts.isAllOnes())
      if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
        return true;
    break;
  }
  case ISD::TRUNCATE:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
    if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
                                   KnownZero, TLO, Depth + 1))
      return true;

    if (Op.getOpcode() == ISD::ZERO_EXTEND) {
      // zext(undef) upper bits are guaranteed to be zero.
      if (DemandedElts.isSubsetOf(KnownUndef))
        return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
      KnownUndef.clearAllBits();
    }
    break;
  default: {
    // Target nodes get a chance via the target hook; other generic nodes fall
    // back to SimplifyDemandedBits, demanding all bits of each demanded elt.
    if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
      if (SimplifyDemandedVectorEltsForTargetNode(Op, DemandedElts, KnownUndef,
                                                  KnownZero, TLO, Depth))
        return true;
    } else {
      KnownBits Known;
      APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
      if (SimplifyDemandedBits(Op, DemandedBits, OriginalDemandedElts, Known,
                               TLO, Depth, AssumeSingleUse))
        return true;
    }
    break;
  }
  }
  assert((KnownUndef & KnownZero) == 0 && "Elements flagged as undef AND zero");

  // Constant fold all undef cases.
  // TODO: Handle zero cases as well.
  if (DemandedElts.isSubsetOf(KnownUndef))
    return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));

  return false;
}
3479 
3480 /// Determine which of the bits specified in Mask are known to be either zero or
3481 /// one and return them in the Known.
computeKnownBitsForTargetNode(const SDValue Op,KnownBits & Known,const APInt & DemandedElts,const SelectionDAG & DAG,unsigned Depth) const3482 void TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3483                                                    KnownBits &Known,
3484                                                    const APInt &DemandedElts,
3485                                                    const SelectionDAG &DAG,
3486                                                    unsigned Depth) const {
3487   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3488           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3489           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3490           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3491          "Should use MaskedValueIsZero if you don't know whether Op"
3492          " is a target node!");
3493   Known.resetAll();
3494 }
3495 
computeKnownBitsForTargetInstr(GISelKnownBits & Analysis,Register R,KnownBits & Known,const APInt & DemandedElts,const MachineRegisterInfo & MRI,unsigned Depth) const3496 void TargetLowering::computeKnownBitsForTargetInstr(
3497     GISelKnownBits &Analysis, Register R, KnownBits &Known,
3498     const APInt &DemandedElts, const MachineRegisterInfo &MRI,
3499     unsigned Depth) const {
3500   Known.resetAll();
3501 }
3502 
computeKnownBitsForFrameIndex(const int FrameIdx,KnownBits & Known,const MachineFunction & MF) const3503 void TargetLowering::computeKnownBitsForFrameIndex(
3504   const int FrameIdx, KnownBits &Known, const MachineFunction &MF) const {
3505   // The low bits are known zero if the pointer is aligned.
3506   Known.Zero.setLowBits(Log2(MF.getFrameInfo().getObjectAlign(FrameIdx)));
3507 }
3508 
computeKnownAlignForTargetInstr(GISelKnownBits & Analysis,Register R,const MachineRegisterInfo & MRI,unsigned Depth) const3509 Align TargetLowering::computeKnownAlignForTargetInstr(
3510   GISelKnownBits &Analysis, Register R, const MachineRegisterInfo &MRI,
3511   unsigned Depth) const {
3512   return Align(1);
3513 }
3514 
3515 /// This method can be implemented by targets that want to expose additional
3516 /// information about sign bits to the DAG Combiner.
ComputeNumSignBitsForTargetNode(SDValue Op,const APInt &,const SelectionDAG &,unsigned Depth) const3517 unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op,
3518                                                          const APInt &,
3519                                                          const SelectionDAG &,
3520                                                          unsigned Depth) const {
3521   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3522           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3523           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3524           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3525          "Should use ComputeNumSignBits if you don't know whether Op"
3526          " is a target node!");
3527   return 1;
3528 }
3529 
computeNumSignBitsForTargetInstr(GISelKnownBits & Analysis,Register R,const APInt & DemandedElts,const MachineRegisterInfo & MRI,unsigned Depth) const3530 unsigned TargetLowering::computeNumSignBitsForTargetInstr(
3531   GISelKnownBits &Analysis, Register R, const APInt &DemandedElts,
3532   const MachineRegisterInfo &MRI, unsigned Depth) const {
3533   return 1;
3534 }
3535 
SimplifyDemandedVectorEltsForTargetNode(SDValue Op,const APInt & DemandedElts,APInt & KnownUndef,APInt & KnownZero,TargetLoweringOpt & TLO,unsigned Depth) const3536 bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3537     SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
3538     TargetLoweringOpt &TLO, unsigned Depth) const {
3539   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3540           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3541           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3542           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3543          "Should use SimplifyDemandedVectorElts if you don't know whether Op"
3544          " is a target node!");
3545   return false;
3546 }
3547 
SimplifyDemandedBitsForTargetNode(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,KnownBits & Known,TargetLoweringOpt & TLO,unsigned Depth) const3548 bool TargetLowering::SimplifyDemandedBitsForTargetNode(
3549     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3550     KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
3551   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3552           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3553           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3554           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3555          "Should use SimplifyDemandedBits if you don't know whether Op"
3556          " is a target node!");
3557   computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
3558   return false;
3559 }
3560 
SimplifyMultipleUseDemandedBitsForTargetNode(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,SelectionDAG & DAG,unsigned Depth) const3561 SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
3562     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3563     SelectionDAG &DAG, unsigned Depth) const {
3564   assert(
3565       (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3566        Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3567        Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3568        Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3569       "Should use SimplifyMultipleUseDemandedBits if you don't know whether Op"
3570       " is a target node!");
3571   return SDValue();
3572 }
3573 
3574 SDValue
buildLegalVectorShuffle(EVT VT,const SDLoc & DL,SDValue N0,SDValue N1,MutableArrayRef<int> Mask,SelectionDAG & DAG) const3575 TargetLowering::buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
3576                                         SDValue N1, MutableArrayRef<int> Mask,
3577                                         SelectionDAG &DAG) const {
3578   bool LegalMask = isShuffleMaskLegal(Mask, VT);
3579   if (!LegalMask) {
3580     std::swap(N0, N1);
3581     ShuffleVectorSDNode::commuteMask(Mask);
3582     LegalMask = isShuffleMaskLegal(Mask, VT);
3583   }
3584 
3585   if (!LegalMask)
3586     return SDValue();
3587 
3588   return DAG.getVectorShuffle(VT, DL, N0, N1, Mask);
3589 }
3590 
getTargetConstantFromLoad(LoadSDNode *) const3591 const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode*) const {
3592   return nullptr;
3593 }
3594 
isGuaranteedNotToBeUndefOrPoisonForTargetNode(SDValue Op,const APInt & DemandedElts,const SelectionDAG & DAG,bool PoisonOnly,unsigned Depth) const3595 bool TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
3596     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3597     bool PoisonOnly, unsigned Depth) const {
3598   assert(
3599       (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3600        Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3601        Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3602        Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3603       "Should use isGuaranteedNotToBeUndefOrPoison if you don't know whether Op"
3604       " is a target node!");
3605   return false;
3606 }
3607 
canCreateUndefOrPoisonForTargetNode(SDValue Op,const APInt & DemandedElts,const SelectionDAG & DAG,bool PoisonOnly,bool ConsiderFlags,unsigned Depth) const3608 bool TargetLowering::canCreateUndefOrPoisonForTargetNode(
3609     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
3610     bool PoisonOnly, bool ConsiderFlags, unsigned Depth) const {
3611   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3612           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3613           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3614           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3615          "Should use canCreateUndefOrPoison if you don't know whether Op"
3616          " is a target node!");
3617   // Be conservative and return true.
3618   return true;
3619 }
3620 
isKnownNeverNaNForTargetNode(SDValue Op,const SelectionDAG & DAG,bool SNaN,unsigned Depth) const3621 bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op,
3622                                                   const SelectionDAG &DAG,
3623                                                   bool SNaN,
3624                                                   unsigned Depth) const {
3625   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3626           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3627           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3628           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3629          "Should use isKnownNeverNaN if you don't know whether Op"
3630          " is a target node!");
3631   return false;
3632 }
3633 
isSplatValueForTargetNode(SDValue Op,const APInt & DemandedElts,APInt & UndefElts,const SelectionDAG & DAG,unsigned Depth) const3634 bool TargetLowering::isSplatValueForTargetNode(SDValue Op,
3635                                                const APInt &DemandedElts,
3636                                                APInt &UndefElts,
3637                                                const SelectionDAG &DAG,
3638                                                unsigned Depth) const {
3639   assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3640           Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3641           Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3642           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3643          "Should use isSplatValue if you don't know whether Op"
3644          " is a target node!");
3645   return false;
3646 }
3647 
3648 // FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must
3649 // work with truncating build vectors and vectors with elements of less than
3650 // 8 bits.
isConstTrueVal(SDValue N) const3651 bool TargetLowering::isConstTrueVal(SDValue N) const {
3652   if (!N)
3653     return false;
3654 
3655   unsigned EltWidth;
3656   APInt CVal;
3657   if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false,
3658                                                /*AllowTruncation=*/true)) {
3659     CVal = CN->getAPIntValue();
3660     EltWidth = N.getValueType().getScalarSizeInBits();
3661   } else
3662     return false;
3663 
3664   // If this is a truncating splat, truncate the splat value.
3665   // Otherwise, we may fail to match the expected values below.
3666   if (EltWidth < CVal.getBitWidth())
3667     CVal = CVal.trunc(EltWidth);
3668 
3669   switch (getBooleanContents(N.getValueType())) {
3670   case UndefinedBooleanContent:
3671     return CVal[0];
3672   case ZeroOrOneBooleanContent:
3673     return CVal.isOne();
3674   case ZeroOrNegativeOneBooleanContent:
3675     return CVal.isAllOnes();
3676   }
3677 
3678   llvm_unreachable("Invalid boolean contents");
3679 }
3680 
isConstFalseVal(SDValue N) const3681 bool TargetLowering::isConstFalseVal(SDValue N) const {
3682   if (!N)
3683     return false;
3684 
3685   const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
3686   if (!CN) {
3687     const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
3688     if (!BV)
3689       return false;
3690 
3691     // Only interested in constant splats, we don't care about undef
3692     // elements in identifying boolean constants and getConstantSplatNode
3693     // returns NULL if all ops are undef;
3694     CN = BV->getConstantSplatNode();
3695     if (!CN)
3696       return false;
3697   }
3698 
3699   if (getBooleanContents(N->getValueType(0)) == UndefinedBooleanContent)
3700     return !CN->getAPIntValue()[0];
3701 
3702   return CN->isZero();
3703 }
3704 
isExtendedTrueVal(const ConstantSDNode * N,EVT VT,bool SExt) const3705 bool TargetLowering::isExtendedTrueVal(const ConstantSDNode *N, EVT VT,
3706                                        bool SExt) const {
3707   if (VT == MVT::i1)
3708     return N->isOne();
3709 
3710   TargetLowering::BooleanContent Cnt = getBooleanContents(VT);
3711   switch (Cnt) {
3712   case TargetLowering::ZeroOrOneBooleanContent:
3713     // An extended value of 1 is always true, unless its original type is i1,
3714     // in which case it will be sign extended to -1.
3715     return (N->isOne() && !SExt) || (SExt && (N->getValueType(0) != MVT::i1));
3716   case TargetLowering::UndefinedBooleanContent:
3717   case TargetLowering::ZeroOrNegativeOneBooleanContent:
3718     return N->isAllOnes() && SExt;
3719   }
3720   llvm_unreachable("Unexpected enumeration.");
3721 }
3722 
/// This helper function of SimplifySetCC tries to optimize the comparison when
/// either operand of the SetCC node is a bitwise-and instruction.
/// Returns a replacement setcc/value on success, or an empty SDValue when no
/// fold applies.
SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1,
                                         ISD::CondCode Cond, const SDLoc &DL,
                                         DAGCombinerInfo &DCI) const {
  // Canonicalize: if only the RHS is an AND, move it to the LHS (N0).
  if (N1.getOpcode() == ISD::AND && N0.getOpcode() != ISD::AND)
    std::swap(N0, N1);

  SelectionDAG &DAG = DCI.DAG;
  EVT OpVT = N0.getValueType();
  // Only integer equality/inequality compares of an AND are handled here.
  if (N0.getOpcode() != ISD::AND || !OpVT.isInteger() ||
      (Cond != ISD::SETEQ && Cond != ISD::SETNE))
    return SDValue();

  // (X & Y) != 0 --> zextOrTrunc(X & Y)
  // iff everything but LSB is known zero:
  if (Cond == ISD::SETNE && isNullConstant(N1) &&
      (getBooleanContents(OpVT) == TargetLowering::UndefinedBooleanContent ||
       getBooleanContents(OpVT) == TargetLowering::ZeroOrOneBooleanContent)) {
    unsigned NumEltBits = OpVT.getScalarSizeInBits();
    APInt UpperBits = APInt::getHighBitsSet(NumEltBits, NumEltBits - 1);
    if (DAG.MaskedValueIsZero(N0, UpperBits))
      return DAG.getBoolExtOrTrunc(N0, DL, VT, OpVT);
  }

  // Try to eliminate a power-of-2 mask constant by converting to a signbit
  // test in a narrow type that we can truncate to with no cost. Examples:
  // (i32 X & 32768) == 0 --> (trunc X to i16) >= 0
  // (i32 X & 32768) != 0 --> (trunc X to i16) < 0
  // TODO: This conservatively checks for type legality on the source and
  //       destination types. That may inhibit optimizations, but it also
  //       allows setcc->shift transforms that may be more beneficial.
  auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  if (AndC && isNullConstant(N1) && AndC->getAPIntValue().isPowerOf2() &&
      isTypeLegal(OpVT) && N0.hasOneUse()) {
    // The narrow type is just wide enough that the mask bit becomes its sign
    // bit.
    EVT NarrowVT = EVT::getIntegerVT(*DAG.getContext(),
                                     AndC->getAPIntValue().getActiveBits());
    if (isTruncateFree(OpVT, NarrowVT) && isTypeLegal(NarrowVT)) {
      SDValue Trunc = DAG.getZExtOrTrunc(N0.getOperand(0), DL, NarrowVT);
      SDValue Zero = DAG.getConstant(0, DL, NarrowVT);
      return DAG.getSetCC(DL, VT, Trunc, Zero,
                          Cond == ISD::SETEQ ? ISD::SETGE : ISD::SETLT);
    }
  }

  // Match these patterns in any of their permutations:
  // (X & Y) == Y
  // (X & Y) != Y
  SDValue X, Y;
  if (N0.getOperand(0) == N1) {
    X = N0.getOperand(1);
    Y = N0.getOperand(0);
  } else if (N0.getOperand(1) == N1) {
    X = N0.getOperand(0);
    Y = N0.getOperand(1);
  } else {
    return SDValue();
  }

  SDValue Zero = DAG.getConstant(0, DL, OpVT);
  if (DAG.isKnownToBeAPowerOfTwo(Y)) {
    // Simplify X & Y == Y to X & Y != 0 if Y has exactly one bit set.
    // Note that where Y is variable and is known to have at most one bit set
    // (for example, if it is Z & 1) we cannot do this; the expressions are not
    // equivalent when Y == 0.
    assert(OpVT.isInteger());
    Cond = ISD::getSetCCInverse(Cond, OpVT);
    if (DCI.isBeforeLegalizeOps() ||
        isCondCodeLegal(Cond, N0.getSimpleValueType()))
      return DAG.getSetCC(DL, VT, N0, Zero, Cond);
  } else if (N0.hasOneUse() && hasAndNotCompare(Y)) {
    // If the target supports an 'and-not' or 'and-complement' logic operation,
    // try to use that to make a comparison operation more efficient.
    // But don't do this transform if the mask is a single bit because there are
    // more efficient ways to deal with that case (for example, 'bt' on x86 or
    // 'rlwinm' on PPC).

    // Bail out if the compare operand that we want to turn into a zero is
    // already a zero (otherwise, infinite loop).
    auto *YConst = dyn_cast<ConstantSDNode>(Y);
    if (YConst && YConst->isZero())
      return SDValue();

    // Transform this into: ~X & Y == 0.
    SDValue NotX = DAG.getNOT(SDLoc(X), X, OpVT);
    SDValue NewAnd = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, NotX, Y);
    return DAG.getSetCC(DL, VT, NewAnd, Zero, Cond);
  }

  return SDValue();
}
3814 
/// There are multiple IR patterns that could be checking whether certain
/// truncation of a signed number would be lossy or not. The pattern which is
/// best at IR level, may not lower optimally. Thus, we want to unfold it.
/// We are looking for the following pattern: (KeptBits is a constant)
///   (add %x, (1 << (KeptBits-1))) srccond (1 << KeptBits)
/// KeptBits won't be bitwidth(x), that will be constant-folded to true/false.
/// KeptBits also can't be 1, that would have been folded to  %x dstcond 0
/// We will unfold it into the natural trunc+sext pattern:
///   ((%x << C) a>> C) dstcond %x
/// Where  C = bitwidth(x) - KeptBits  and  C u< bitwidth(x)
SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck(
    EVT SCCVT, SDValue N0, SDValue N1, ISD::CondCode Cond, DAGCombinerInfo &DCI,
    const SDLoc &DL) const {
  // We must be comparing with a constant.
  ConstantSDNode *C1;
  if (!(C1 = dyn_cast<ConstantSDNode>(N1)))
    return SDValue();

  // N0 should be:  add %x, (1 << (KeptBits-1))
  if (N0->getOpcode() != ISD::ADD)
    return SDValue();

  // And we must be 'add'ing a constant.
  ConstantSDNode *C01;
  if (!(C01 = dyn_cast<ConstantSDNode>(N0->getOperand(1))))
    return SDValue();

  SDValue X = N0->getOperand(0);
  EVT XVT = X.getValueType();

  // Validate constants ...

  APInt I1 = C1->getAPIntValue();

  // Map the unsigned range check onto an equality check, canonicalizing the
  // non-strict forms (SETULE/SETUGT) by bumping the constant by one so that
  // the power-of-two checks below apply uniformly.
  ISD::CondCode NewCond;
  if (Cond == ISD::CondCode::SETULT) {
    NewCond = ISD::CondCode::SETEQ;
  } else if (Cond == ISD::CondCode::SETULE) {
    NewCond = ISD::CondCode::SETEQ;
    // But need to 'canonicalize' the constant.
    I1 += 1;
  } else if (Cond == ISD::CondCode::SETUGT) {
    NewCond = ISD::CondCode::SETNE;
    // But need to 'canonicalize' the constant.
    I1 += 1;
  } else if (Cond == ISD::CondCode::SETUGE) {
    NewCond = ISD::CondCode::SETNE;
  } else
    return SDValue();

  APInt I01 = C01->getAPIntValue();

  auto checkConstants = [&I1, &I01]() -> bool {
    // Both of them must be power-of-two, and the constant from setcc is bigger.
    return I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2();
  };

  if (checkConstants()) {
    // Great, e.g. got  icmp ult i16 (add i16 %x, 128), 256
  } else {
    // What if we invert constants? (and the target predicate)
    I1.negate();
    I01.negate();
    assert(XVT.isInteger());
    NewCond = getSetCCInverse(NewCond, XVT);
    if (!checkConstants())
      return SDValue();
    // Great, e.g. got  icmp uge i16 (add i16 %x, -128), -256
  }

  // They are power-of-two, so which bit is set?
  const unsigned KeptBits = I1.logBase2();
  const unsigned KeptBitsMinusOne = I01.logBase2();

  // Magic!
  // The two constants must be consecutive powers of two, i.e. 1<<KeptBits and
  // 1<<(KeptBits-1), or this is not the pattern we are looking for.
  if (KeptBits != (KeptBitsMinusOne + 1))
    return SDValue();
  assert(KeptBits > 0 && KeptBits < XVT.getSizeInBits() && "unreachable");

  // We don't want to do this in every single case.
  SelectionDAG &DAG = DCI.DAG;
  if (!DAG.getTargetLoweringInfo().shouldTransformSignedTruncationCheck(
          XVT, KeptBits))
    return SDValue();

  const unsigned MaskedBits = XVT.getSizeInBits() - KeptBits;
  assert(MaskedBits > 0 && MaskedBits < XVT.getSizeInBits() && "unreachable");

  // Unfold into:  ((%x << C) a>> C) cond %x
  // Where 'cond' will be either 'eq' or 'ne'.
  SDValue ShiftAmt = DAG.getConstant(MaskedBits, DL, XVT);
  SDValue T0 = DAG.getNode(ISD::SHL, DL, XVT, X, ShiftAmt);
  SDValue T1 = DAG.getNode(ISD::SRA, DL, XVT, T0, ShiftAmt);
  SDValue T2 = DAG.getSetCC(DL, SCCVT, T1, X, NewCond);

  return T2;
}
3912 
// (X & (C l>>/<< Y)) ==/!= 0  -->  ((X <</l>> Y) & C) ==/!= 0
//
// Hoists a constant out of a shifted mask in an equality-with-zero compare by
// applying the opposite shift to the other AND operand instead, when the
// target says that form is preferable.
SDValue TargetLowering::optimizeSetCCByHoistingAndByConstFromLogicalShift(
    EVT SCCVT, SDValue N0, SDValue N1C, ISD::CondCode Cond,
    DAGCombinerInfo &DCI, const SDLoc &DL) const {
  assert(isConstOrConstSplat(N1C) &&
         isConstOrConstSplat(N1C)->getAPIntValue().isZero() &&
         "Should be a comparison with 0.");
  assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
         "Valid only for [in]equality comparisons.");

  // Filled in by the Match lambda below; X is the non-shift AND operand,
  // C the shifted constant, Y the shift amount.
  unsigned NewShiftOpcode;
  SDValue X, C, Y;

  SelectionDAG &DAG = DCI.DAG;
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();

  // Look for '(C l>>/<< Y)'.
  auto Match = [&NewShiftOpcode, &X, &C, &Y, &TLI, &DAG](SDValue V) {
    // The shift should be one-use.
    if (!V.hasOneUse())
      return false;
    unsigned OldShiftOpcode = V.getOpcode();
    switch (OldShiftOpcode) {
    case ISD::SHL:
      NewShiftOpcode = ISD::SRL;
      break;
    case ISD::SRL:
      NewShiftOpcode = ISD::SHL;
      break;
    default:
      return false; // must be a logical shift.
    }
    // We should be shifting a constant.
    // FIXME: best to use isConstantOrConstantVector().
    C = V.getOperand(0);
    ConstantSDNode *CC =
        isConstOrConstSplat(C, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
    if (!CC)
      return false;
    Y = V.getOperand(1);

    // Note: X is captured by reference and set by the caller before invoking
    // this lambda; the target hook gets the full candidate transform.
    ConstantSDNode *XC =
        isConstOrConstSplat(X, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
    return TLI.shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
        X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG);
  };

  // LHS of comparison should be an one-use 'and'.
  if (N0.getOpcode() != ISD::AND || !N0.hasOneUse())
    return SDValue();

  X = N0.getOperand(0);
  SDValue Mask = N0.getOperand(1);

  // 'and' is commutative!
  if (!Match(Mask)) {
    std::swap(X, Mask);
    if (!Match(Mask))
      return SDValue();
  }

  EVT VT = X.getValueType();

  // Produce:
  // ((X 'OppositeShiftOpcode' Y) & C) Cond 0
  SDValue T0 = DAG.getNode(NewShiftOpcode, DL, VT, X, Y);
  SDValue T1 = DAG.getNode(ISD::AND, DL, VT, T0, C);
  SDValue T2 = DAG.getSetCC(DL, SCCVT, T1, N1C, Cond);
  return T2;
}
3983 
/// Try to fold an equality comparison with a {add/sub/xor} binary operation as
/// the 1st operand (N0). Callers are expected to swap the N0/N1 parameters to
/// handle the commuted versions of these patterns.
/// Returns the simplified setcc, or an empty SDValue if nothing folds.
SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
                                           ISD::CondCode Cond, const SDLoc &DL,
                                           DAGCombinerInfo &DCI) const {
  unsigned BOpcode = N0.getOpcode();
  assert((BOpcode == ISD::ADD || BOpcode == ISD::SUB || BOpcode == ISD::XOR) &&
         "Unexpected binop");
  assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Unexpected condcode");

  // (X + Y) == X --> Y == 0
  // (X - Y) == X --> Y == 0
  // (X ^ Y) == X --> Y == 0
  SelectionDAG &DAG = DCI.DAG;
  EVT OpVT = N0.getValueType();
  SDValue X = N0.getOperand(0);
  SDValue Y = N0.getOperand(1);
  if (X == N1)
    return DAG.getSetCC(DL, VT, Y, DAG.getConstant(0, DL, OpVT), Cond);

  if (Y != N1)
    return SDValue();

  // (X + Y) == Y --> X == 0
  // (X ^ Y) == Y --> X == 0
  if (BOpcode == ISD::ADD || BOpcode == ISD::XOR)
    return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, OpVT), Cond);

  // The shift would not be valid if the operands are boolean (i1).
  if (!N0.hasOneUse() || OpVT.getScalarSizeInBits() == 1)
    return SDValue();

  // (X - Y) == Y --> X == Y << 1
  // (X - Y == Y is equivalent to X == 2*Y in wrapping two's-complement
  // arithmetic, and 2*Y is expressed as a shift-left by one.)
  EVT ShiftVT = getShiftAmountTy(OpVT, DAG.getDataLayout(),
                                 !DCI.isBeforeLegalize());
  SDValue One = DAG.getConstant(1, DL, ShiftVT);
  SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
  if (!DCI.isCalledByLegalizer())
    DCI.AddToWorklist(YShl1.getNode());
  return DAG.getSetCC(DL, VT, X, YShl1, Cond);
}
4026 
/// Try to simplify a setcc whose LHS is a CTPOP (population count) against a
/// constant C1, expanding power-of-two(-or-zero) tests into cheaper bit
/// tricks of the form (x & (x-1)). Returns an empty SDValue if no fold
/// applies.
static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT,
                                      SDValue N0, const APInt &C1,
                                      ISD::CondCode Cond, const SDLoc &dl,
                                      SelectionDAG &DAG) {
  // Look through truncs that don't change the value of a ctpop.
  // FIXME: Add vector support? Need to be careful with setcc result type below.
  // The truncated type must still be wide enough to hold the largest possible
  // popcount of the wider source, i.e. more than log2(source bits) bits.
  SDValue CTPOP = N0;
  if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && !VT.isVector() &&
      N0.getScalarValueSizeInBits() > Log2_32(N0.getOperand(0).getScalarValueSizeInBits()))
    CTPOP = N0.getOperand(0);

  if (CTPOP.getOpcode() != ISD::CTPOP || !CTPOP.hasOneUse())
    return SDValue();

  EVT CTVT = CTPOP.getValueType();
  SDValue CTOp = CTPOP.getOperand(0);

  // Expand a power-of-2-or-zero comparison based on ctpop:
  // (ctpop x) u< 2 -> (x & x-1) == 0
  // (ctpop x) u> 1 -> (x & x-1) != 0
  if (Cond == ISD::SETULT || Cond == ISD::SETUGT) {
    // Keep the CTPOP if it is a legal vector op.
    if (CTVT.isVector() && TLI.isOperationLegal(ISD::CTPOP, CTVT))
      return SDValue();

    // Each "clear lowest set bit" pass costs an add+and; ask the target how
    // many passes are worth emitting instead of a real ctpop.
    unsigned CostLimit = TLI.getCustomCtpopCost(CTVT, Cond);
    if (C1.ugt(CostLimit + (Cond == ISD::SETULT)))
      return SDValue();
    if (C1 == 0 && (Cond == ISD::SETULT))
      return SDValue(); // This is handled elsewhere.

    unsigned Passes = C1.getLimitedValue() - (Cond == ISD::SETULT);

    // Repeatedly clear the lowest set bit: Result &= (Result - 1).
    SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
    SDValue Result = CTOp;
    for (unsigned i = 0; i < Passes; i++) {
      SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, Result, NegOne);
      Result = DAG.getNode(ISD::AND, dl, CTVT, Result, Add);
    }
    ISD::CondCode CC = Cond == ISD::SETULT ? ISD::SETEQ : ISD::SETNE;
    return DAG.getSetCC(dl, VT, Result, DAG.getConstant(0, dl, CTVT), CC);
  }

  // Expand a power-of-2 comparison based on ctpop:
  // (ctpop x) == 1 --> (x != 0) && ((x & x-1) == 0)
  // (ctpop x) != 1 --> (x == 0) || ((x & x-1) != 0)
  if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && C1 == 1) {
    // Keep the CTPOP if it is legal.
    if (TLI.isOperationLegal(ISD::CTPOP, CTVT))
      return SDValue();

    SDValue Zero = DAG.getConstant(0, dl, CTVT);
    SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
    assert(CTVT.isInteger());
    ISD::CondCode InvCond = ISD::getSetCCInverse(Cond, CTVT);
    SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, CTOp, NegOne);
    SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add);
    SDValue LHS = DAG.getSetCC(dl, VT, CTOp, Zero, InvCond);
    SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond);
    unsigned LogicOpcode = Cond == ISD::SETEQ ? ISD::AND : ISD::OR;
    return DAG.getNode(LogicOpcode, dl, VT, LHS, RHS);
  }

  return SDValue();
}
4092 
/// Fold an equality compare of a rotated value against 0 or -1 by peeking
/// through the rotate: rotating cannot change whether a value is all-zeros or
/// all-ones. Returns an empty SDValue if no fold applies.
static SDValue foldSetCCWithRotate(EVT VT, SDValue N0, SDValue N1,
                                   ISD::CondCode Cond, const SDLoc &dl,
                                   SelectionDAG &DAG) {
  if (Cond != ISD::SETEQ && Cond != ISD::SETNE)
    return SDValue();

  // Only 0 and -1 survive rotation unchanged, so only those RHS values fold.
  auto *C1 = isConstOrConstSplat(N1, /* AllowUndefs */ true);
  if (!C1 || !(C1->isZero() || C1->isAllOnes()))
    return SDValue();

  auto getRotateSource = [](SDValue X) {
    if (X.getOpcode() == ISD::ROTL || X.getOpcode() == ISD::ROTR)
      return X.getOperand(0);
    return SDValue();
  };

  // Peek through a rotated value compared against 0 or -1:
  // (rot X, Y) == 0/-1 --> X == 0/-1
  // (rot X, Y) != 0/-1 --> X != 0/-1
  if (SDValue R = getRotateSource(N0))
    return DAG.getSetCC(dl, VT, R, N1, Cond);

  // Peek through an 'or' of a rotated value compared against 0:
  // or (rot X, Y), Z ==/!= 0 --> (or X, Z) ==/!= 0
  // or Z, (rot X, Y) ==/!= 0 --> (or X, Z) ==/!= 0
  //
  // TODO: Add the 'and' with -1 sibling.
  // TODO: Recurse through a series of 'or' ops to find the rotate.
  EVT OpVT = N0.getValueType();
  if (N0.hasOneUse() && N0.getOpcode() == ISD::OR && C1->isZero()) {
    if (SDValue R = getRotateSource(N0.getOperand(0))) {
      SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, R, N0.getOperand(1));
      return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
    }
    if (SDValue R = getRotateSource(N0.getOperand(1))) {
      SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, R, N0.getOperand(0));
      return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
    }
  }

  return SDValue();
}
4135 
/// Fold an all-bits-clear test of a funnel shift whose two inputs share an
/// operand through an 'or', replacing the funnel shift with a single plain
/// shift. Returns an empty SDValue if no fold applies.
static SDValue foldSetCCWithFunnelShift(EVT VT, SDValue N0, SDValue N1,
                                        ISD::CondCode Cond, const SDLoc &dl,
                                        SelectionDAG &DAG) {
  // If we are testing for all-bits-clear, we might be able to do that with
  // less shifting since bit-order does not matter.
  if (Cond != ISD::SETEQ && Cond != ISD::SETNE)
    return SDValue();

  auto *C1 = isConstOrConstSplat(N1, /* AllowUndefs */ true);
  if (!C1 || !C1->isZero())
    return SDValue();

  if (!N0.hasOneUse() ||
      (N0.getOpcode() != ISD::FSHL && N0.getOpcode() != ISD::FSHR))
    return SDValue();

  // The shift amount must be a constant strictly less than the bit width.
  unsigned BitWidth = N0.getScalarValueSizeInBits();
  auto *ShAmtC = isConstOrConstSplat(N0.getOperand(2));
  if (!ShAmtC || ShAmtC->getAPIntValue().uge(BitWidth))
    return SDValue();

  // Canonicalize fshr as fshl to reduce pattern-matching.
  unsigned ShAmt = ShAmtC->getZExtValue();
  if (N0.getOpcode() == ISD::FSHR)
    ShAmt = BitWidth - ShAmt;

  // Match an 'or' with a specific operand 'Other' in either commuted variant.
  SDValue X, Y;
  auto matchOr = [&X, &Y](SDValue Or, SDValue Other) {
    if (Or.getOpcode() != ISD::OR || !Or.hasOneUse())
      return false;
    if (Or.getOperand(0) == Other) {
      X = Or.getOperand(0);
      Y = Or.getOperand(1);
      return true;
    }
    if (Or.getOperand(1) == Other) {
      X = Or.getOperand(1);
      Y = Or.getOperand(0);
      return true;
    }
    return false;
  };

  EVT OpVT = N0.getValueType();
  EVT ShAmtVT = N0.getOperand(2).getValueType();
  SDValue F0 = N0.getOperand(0);
  SDValue F1 = N0.getOperand(1);
  if (matchOr(F0, F1)) {
    // fshl (or X, Y), X, C ==/!= 0 --> or (shl Y, C), X ==/!= 0
    SDValue NewShAmt = DAG.getConstant(ShAmt, dl, ShAmtVT);
    SDValue Shift = DAG.getNode(ISD::SHL, dl, OpVT, Y, NewShAmt);
    SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, Shift, X);
    return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
  }
  if (matchOr(F1, F0)) {
    // fshl X, (or X, Y), C ==/!= 0 --> or (srl Y, BW-C), X ==/!= 0
    SDValue NewShAmt = DAG.getConstant(BitWidth - ShAmt, dl, ShAmtVT);
    SDValue Shift = DAG.getNode(ISD::SRL, dl, OpVT, Y, NewShAmt);
    SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, Shift, X);
    return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
  }

  return SDValue();
}
4201 
4202 /// Try to simplify a setcc built with the specified operands and cc. If it is
4203 /// unable to simplify it, return a null SDValue.
SimplifySetCC(EVT VT,SDValue N0,SDValue N1,ISD::CondCode Cond,bool foldBooleans,DAGCombinerInfo & DCI,const SDLoc & dl) const4204 SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
4205                                       ISD::CondCode Cond, bool foldBooleans,
4206                                       DAGCombinerInfo &DCI,
4207                                       const SDLoc &dl) const {
4208   SelectionDAG &DAG = DCI.DAG;
4209   const DataLayout &Layout = DAG.getDataLayout();
4210   EVT OpVT = N0.getValueType();
4211   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4212 
4213   // Constant fold or commute setcc.
4214   if (SDValue Fold = DAG.FoldSetCC(VT, N0, N1, Cond, dl))
4215     return Fold;
4216 
4217   bool N0ConstOrSplat =
4218       isConstOrConstSplat(N0, /*AllowUndefs*/ false, /*AllowTruncate*/ true);
4219   bool N1ConstOrSplat =
4220       isConstOrConstSplat(N1, /*AllowUndefs*/ false, /*AllowTruncate*/ true);
4221 
4222   // Ensure that the constant occurs on the RHS and fold constant comparisons.
4223   // TODO: Handle non-splat vector constants. All undef causes trouble.
4224   // FIXME: We can't yet fold constant scalable vector splats, so avoid an
4225   // infinite loop here when we encounter one.
4226   ISD::CondCode SwappedCC = ISD::getSetCCSwappedOperands(Cond);
4227   if (N0ConstOrSplat && (!OpVT.isScalableVector() || !N1ConstOrSplat) &&
4228       (DCI.isBeforeLegalizeOps() ||
4229        isCondCodeLegal(SwappedCC, N0.getSimpleValueType())))
4230     return DAG.getSetCC(dl, VT, N1, N0, SwappedCC);
4231 
4232   // If we have a subtract with the same 2 non-constant operands as this setcc
4233   // -- but in reverse order -- then try to commute the operands of this setcc
4234   // to match. A matching pair of setcc (cmp) and sub may be combined into 1
4235   // instruction on some targets.
4236   if (!N0ConstOrSplat && !N1ConstOrSplat &&
4237       (DCI.isBeforeLegalizeOps() ||
4238        isCondCodeLegal(SwappedCC, N0.getSimpleValueType())) &&
4239       DAG.doesNodeExist(ISD::SUB, DAG.getVTList(OpVT), {N1, N0}) &&
4240       !DAG.doesNodeExist(ISD::SUB, DAG.getVTList(OpVT), {N0, N1}))
4241     return DAG.getSetCC(dl, VT, N1, N0, SwappedCC);
4242 
4243   if (SDValue V = foldSetCCWithRotate(VT, N0, N1, Cond, dl, DAG))
4244     return V;
4245 
4246   if (SDValue V = foldSetCCWithFunnelShift(VT, N0, N1, Cond, dl, DAG))
4247     return V;
4248 
4249   if (auto *N1C = isConstOrConstSplat(N1)) {
4250     const APInt &C1 = N1C->getAPIntValue();
4251 
4252     // Optimize some CTPOP cases.
4253     if (SDValue V = simplifySetCCWithCTPOP(*this, VT, N0, C1, Cond, dl, DAG))
4254       return V;
4255 
4256     // For equality to 0 of a no-wrap multiply, decompose and test each op:
4257     // X * Y == 0 --> (X == 0) || (Y == 0)
4258     // X * Y != 0 --> (X != 0) && (Y != 0)
4259     // TODO: This bails out if minsize is set, but if the target doesn't have a
4260     //       single instruction multiply for this type, it would likely be
4261     //       smaller to decompose.
4262     if (C1.isZero() && (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4263         N0.getOpcode() == ISD::MUL && N0.hasOneUse() &&
4264         (N0->getFlags().hasNoUnsignedWrap() ||
4265          N0->getFlags().hasNoSignedWrap()) &&
4266         !Attr.hasFnAttr(Attribute::MinSize)) {
4267       SDValue IsXZero = DAG.getSetCC(dl, VT, N0.getOperand(0), N1, Cond);
4268       SDValue IsYZero = DAG.getSetCC(dl, VT, N0.getOperand(1), N1, Cond);
4269       unsigned LogicOp = Cond == ISD::SETEQ ? ISD::OR : ISD::AND;
4270       return DAG.getNode(LogicOp, dl, VT, IsXZero, IsYZero);
4271     }
4272 
4273     // If the LHS is '(srl (ctlz x), 5)', the RHS is 0/1, and this is an
4274     // equality comparison, then we're just comparing whether X itself is
4275     // zero.
4276     if (N0.getOpcode() == ISD::SRL && (C1.isZero() || C1.isOne()) &&
4277         N0.getOperand(0).getOpcode() == ISD::CTLZ &&
4278         isPowerOf2_32(N0.getScalarValueSizeInBits())) {
4279       if (ConstantSDNode *ShAmt = isConstOrConstSplat(N0.getOperand(1))) {
4280         if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4281             ShAmt->getAPIntValue() == Log2_32(N0.getScalarValueSizeInBits())) {
4282           if ((C1 == 0) == (Cond == ISD::SETEQ)) {
4283             // (srl (ctlz x), 5) == 0  -> X != 0
4284             // (srl (ctlz x), 5) != 1  -> X != 0
4285             Cond = ISD::SETNE;
4286           } else {
4287             // (srl (ctlz x), 5) != 0  -> X == 0
4288             // (srl (ctlz x), 5) == 1  -> X == 0
4289             Cond = ISD::SETEQ;
4290           }
4291           SDValue Zero = DAG.getConstant(0, dl, N0.getValueType());
4292           return DAG.getSetCC(dl, VT, N0.getOperand(0).getOperand(0), Zero,
4293                               Cond);
4294         }
4295       }
4296     }
4297   }
4298 
4299   // FIXME: Support vectors.
4300   if (auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) {
4301     const APInt &C1 = N1C->getAPIntValue();
4302 
4303     // (zext x) == C --> x == (trunc C)
4304     // (sext x) == C --> x == (trunc C)
4305     if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4306         DCI.isBeforeLegalize() && N0->hasOneUse()) {
4307       unsigned MinBits = N0.getValueSizeInBits();
4308       SDValue PreExt;
4309       bool Signed = false;
4310       if (N0->getOpcode() == ISD::ZERO_EXTEND) {
4311         // ZExt
4312         MinBits = N0->getOperand(0).getValueSizeInBits();
4313         PreExt = N0->getOperand(0);
4314       } else if (N0->getOpcode() == ISD::AND) {
4315         // DAGCombine turns costly ZExts into ANDs
4316         if (auto *C = dyn_cast<ConstantSDNode>(N0->getOperand(1)))
4317           if ((C->getAPIntValue()+1).isPowerOf2()) {
4318             MinBits = C->getAPIntValue().countTrailingOnes();
4319             PreExt = N0->getOperand(0);
4320           }
4321       } else if (N0->getOpcode() == ISD::SIGN_EXTEND) {
4322         // SExt
4323         MinBits = N0->getOperand(0).getValueSizeInBits();
4324         PreExt = N0->getOperand(0);
4325         Signed = true;
4326       } else if (auto *LN0 = dyn_cast<LoadSDNode>(N0)) {
4327         // ZEXTLOAD / SEXTLOAD
4328         if (LN0->getExtensionType() == ISD::ZEXTLOAD) {
4329           MinBits = LN0->getMemoryVT().getSizeInBits();
4330           PreExt = N0;
4331         } else if (LN0->getExtensionType() == ISD::SEXTLOAD) {
4332           Signed = true;
4333           MinBits = LN0->getMemoryVT().getSizeInBits();
4334           PreExt = N0;
4335         }
4336       }
4337 
4338       // Figure out how many bits we need to preserve this constant.
4339       unsigned ReqdBits = Signed ? C1.getMinSignedBits() : C1.getActiveBits();
4340 
4341       // Make sure we're not losing bits from the constant.
4342       if (MinBits > 0 &&
4343           MinBits < C1.getBitWidth() &&
4344           MinBits >= ReqdBits) {
4345         EVT MinVT = EVT::getIntegerVT(*DAG.getContext(), MinBits);
4346         if (isTypeDesirableForOp(ISD::SETCC, MinVT)) {
4347           // Will get folded away.
4348           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, MinVT, PreExt);
4349           if (MinBits == 1 && C1 == 1)
4350             // Invert the condition.
4351             return DAG.getSetCC(dl, VT, Trunc, DAG.getConstant(0, dl, MVT::i1),
4352                                 Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
4353           SDValue C = DAG.getConstant(C1.trunc(MinBits), dl, MinVT);
4354           return DAG.getSetCC(dl, VT, Trunc, C, Cond);
4355         }
4356 
4357         // If truncating the setcc operands is not desirable, we can still
4358         // simplify the expression in some cases:
4359         // setcc ([sz]ext (setcc x, y, cc)), 0, setne) -> setcc (x, y, cc)
4360         // setcc ([sz]ext (setcc x, y, cc)), 0, seteq) -> setcc (x, y, inv(cc))
4361         // setcc (zext (setcc x, y, cc)), 1, setne) -> setcc (x, y, inv(cc))
4362         // setcc (zext (setcc x, y, cc)), 1, seteq) -> setcc (x, y, cc)
4363         // setcc (sext (setcc x, y, cc)), -1, setne) -> setcc (x, y, inv(cc))
4364         // setcc (sext (setcc x, y, cc)), -1, seteq) -> setcc (x, y, cc)
4365         SDValue TopSetCC = N0->getOperand(0);
4366         unsigned N0Opc = N0->getOpcode();
4367         bool SExt = (N0Opc == ISD::SIGN_EXTEND);
4368         if (TopSetCC.getValueType() == MVT::i1 && VT == MVT::i1 &&
4369             TopSetCC.getOpcode() == ISD::SETCC &&
4370             (N0Opc == ISD::ZERO_EXTEND || N0Opc == ISD::SIGN_EXTEND) &&
4371             (isConstFalseVal(N1) ||
4372              isExtendedTrueVal(N1C, N0->getValueType(0), SExt))) {
4373 
4374           bool Inverse = (N1C->isZero() && Cond == ISD::SETEQ) ||
4375                          (!N1C->isZero() && Cond == ISD::SETNE);
4376 
4377           if (!Inverse)
4378             return TopSetCC;
4379 
4380           ISD::CondCode InvCond = ISD::getSetCCInverse(
4381               cast<CondCodeSDNode>(TopSetCC.getOperand(2))->get(),
4382               TopSetCC.getOperand(0).getValueType());
4383           return DAG.getSetCC(dl, VT, TopSetCC.getOperand(0),
4384                                       TopSetCC.getOperand(1),
4385                                       InvCond);
4386         }
4387       }
4388     }
4389 
4390     // If the LHS is '(and load, const)', the RHS is 0, the test is for
4391     // equality or unsigned, and all 1 bits of the const are in the same
4392     // partial word, see if we can shorten the load.
4393     if (DCI.isBeforeLegalize() &&
4394         !ISD::isSignedIntSetCC(Cond) &&
4395         N0.getOpcode() == ISD::AND && C1 == 0 &&
4396         N0.getNode()->hasOneUse() &&
4397         isa<LoadSDNode>(N0.getOperand(0)) &&
4398         N0.getOperand(0).getNode()->hasOneUse() &&
4399         isa<ConstantSDNode>(N0.getOperand(1))) {
4400       LoadSDNode *Lod = cast<LoadSDNode>(N0.getOperand(0));
4401       APInt bestMask;
4402       unsigned bestWidth = 0, bestOffset = 0;
4403       if (Lod->isSimple() && Lod->isUnindexed()) {
4404         unsigned origWidth = N0.getValueSizeInBits();
4405         unsigned maskWidth = origWidth;
4406         // We can narrow (e.g.) 16-bit extending loads on 32-bit target to
4407         // 8 bits, but have to be careful...
4408         if (Lod->getExtensionType() != ISD::NON_EXTLOAD)
4409           origWidth = Lod->getMemoryVT().getSizeInBits();
4410         const APInt &Mask = N0.getConstantOperandAPInt(1);
4411         for (unsigned width = origWidth / 2; width>=8; width /= 2) {
4412           APInt newMask = APInt::getLowBitsSet(maskWidth, width);
4413           for (unsigned offset=0; offset<origWidth/width; offset++) {
4414             if (Mask.isSubsetOf(newMask)) {
4415               if (Layout.isLittleEndian())
4416                 bestOffset = (uint64_t)offset * (width/8);
4417               else
4418                 bestOffset = (origWidth/width - offset - 1) * (width/8);
4419               bestMask = Mask.lshr(offset * (width/8) * 8);
4420               bestWidth = width;
4421               break;
4422             }
4423             newMask <<= width;
4424           }
4425         }
4426       }
4427       if (bestWidth) {
4428         EVT newVT = EVT::getIntegerVT(*DAG.getContext(), bestWidth);
4429         if (newVT.isRound() &&
4430             shouldReduceLoadWidth(Lod, ISD::NON_EXTLOAD, newVT)) {
4431           SDValue Ptr = Lod->getBasePtr();
4432           if (bestOffset != 0)
4433             Ptr =
4434                 DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(bestOffset), dl);
4435           SDValue NewLoad =
4436               DAG.getLoad(newVT, dl, Lod->getChain(), Ptr,
4437                           Lod->getPointerInfo().getWithOffset(bestOffset),
4438                           Lod->getOriginalAlign());
4439           return DAG.getSetCC(dl, VT,
4440                               DAG.getNode(ISD::AND, dl, newVT, NewLoad,
4441                                       DAG.getConstant(bestMask.trunc(bestWidth),
4442                                                       dl, newVT)),
4443                               DAG.getConstant(0LL, dl, newVT), Cond);
4444         }
4445       }
4446     }
4447 
4448     // If the LHS is a ZERO_EXTEND, perform the comparison on the input.
4449     if (N0.getOpcode() == ISD::ZERO_EXTEND) {
4450       unsigned InSize = N0.getOperand(0).getValueSizeInBits();
4451 
4452       // If the comparison constant has bits in the upper part, the
4453       // zero-extended value could never match.
4454       if (C1.intersects(APInt::getHighBitsSet(C1.getBitWidth(),
4455                                               C1.getBitWidth() - InSize))) {
4456         switch (Cond) {
4457         case ISD::SETUGT:
4458         case ISD::SETUGE:
4459         case ISD::SETEQ:
4460           return DAG.getConstant(0, dl, VT);
4461         case ISD::SETULT:
4462         case ISD::SETULE:
4463         case ISD::SETNE:
4464           return DAG.getConstant(1, dl, VT);
4465         case ISD::SETGT:
4466         case ISD::SETGE:
4467           // True if the sign bit of C1 is set.
4468           return DAG.getConstant(C1.isNegative(), dl, VT);
4469         case ISD::SETLT:
4470         case ISD::SETLE:
4471           // True if the sign bit of C1 isn't set.
4472           return DAG.getConstant(C1.isNonNegative(), dl, VT);
4473         default:
4474           break;
4475         }
4476       }
4477 
4478       // Otherwise, we can perform the comparison with the low bits.
4479       switch (Cond) {
4480       case ISD::SETEQ:
4481       case ISD::SETNE:
4482       case ISD::SETUGT:
4483       case ISD::SETUGE:
4484       case ISD::SETULT:
4485       case ISD::SETULE: {
4486         EVT newVT = N0.getOperand(0).getValueType();
4487         if (DCI.isBeforeLegalizeOps() ||
4488             (isOperationLegal(ISD::SETCC, newVT) &&
4489              isCondCodeLegal(Cond, newVT.getSimpleVT()))) {
4490           EVT NewSetCCVT = getSetCCResultType(Layout, *DAG.getContext(), newVT);
4491           SDValue NewConst = DAG.getConstant(C1.trunc(InSize), dl, newVT);
4492 
4493           SDValue NewSetCC = DAG.getSetCC(dl, NewSetCCVT, N0.getOperand(0),
4494                                           NewConst, Cond);
4495           return DAG.getBoolExtOrTrunc(NewSetCC, dl, VT, N0.getValueType());
4496         }
4497         break;
4498       }
4499       default:
4500         break; // todo, be more careful with signed comparisons
4501       }
4502     } else if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
4503                (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4504                !isSExtCheaperThanZExt(cast<VTSDNode>(N0.getOperand(1))->getVT(),
4505                                       OpVT)) {
4506       EVT ExtSrcTy = cast<VTSDNode>(N0.getOperand(1))->getVT();
4507       unsigned ExtSrcTyBits = ExtSrcTy.getSizeInBits();
4508       EVT ExtDstTy = N0.getValueType();
4509       unsigned ExtDstTyBits = ExtDstTy.getSizeInBits();
4510 
4511       // If the constant doesn't fit into the number of bits for the source of
4512       // the sign extension, it is impossible for both sides to be equal.
4513       if (C1.getMinSignedBits() > ExtSrcTyBits)
4514         return DAG.getBoolConstant(Cond == ISD::SETNE, dl, VT, OpVT);
4515 
4516       assert(ExtDstTy == N0.getOperand(0).getValueType() &&
4517              ExtDstTy != ExtSrcTy && "Unexpected types!");
4518       APInt Imm = APInt::getLowBitsSet(ExtDstTyBits, ExtSrcTyBits);
4519       SDValue ZextOp = DAG.getNode(ISD::AND, dl, ExtDstTy, N0.getOperand(0),
4520                                    DAG.getConstant(Imm, dl, ExtDstTy));
4521       if (!DCI.isCalledByLegalizer())
4522         DCI.AddToWorklist(ZextOp.getNode());
4523       // Otherwise, make this a use of a zext.
4524       return DAG.getSetCC(dl, VT, ZextOp,
4525                           DAG.getConstant(C1 & Imm, dl, ExtDstTy), Cond);
4526     } else if ((N1C->isZero() || N1C->isOne()) &&
4527                (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
4528       // SETCC (SETCC), [0|1], [EQ|NE]  -> SETCC
4529       if (N0.getOpcode() == ISD::SETCC &&
4530           isTypeLegal(VT) && VT.bitsLE(N0.getValueType()) &&
4531           (N0.getValueType() == MVT::i1 ||
4532            getBooleanContents(N0.getOperand(0).getValueType()) ==
4533                        ZeroOrOneBooleanContent)) {
4534         bool TrueWhenTrue = (Cond == ISD::SETEQ) ^ (!N1C->isOne());
4535         if (TrueWhenTrue)
4536           return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
4537         // Invert the condition.
4538         ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
4539         CC = ISD::getSetCCInverse(CC, N0.getOperand(0).getValueType());
4540         if (DCI.isBeforeLegalizeOps() ||
4541             isCondCodeLegal(CC, N0.getOperand(0).getSimpleValueType()))
4542           return DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC);
4543       }
4544 
4545       if ((N0.getOpcode() == ISD::XOR ||
4546            (N0.getOpcode() == ISD::AND &&
4547             N0.getOperand(0).getOpcode() == ISD::XOR &&
4548             N0.getOperand(1) == N0.getOperand(0).getOperand(1))) &&
4549           isOneConstant(N0.getOperand(1))) {
4550         // If this is (X^1) == 0/1, swap the RHS and eliminate the xor.  We
4551         // can only do this if the top bits are known zero.
4552         unsigned BitWidth = N0.getValueSizeInBits();
4553         if (DAG.MaskedValueIsZero(N0,
4554                                   APInt::getHighBitsSet(BitWidth,
4555                                                         BitWidth-1))) {
4556           // Okay, get the un-inverted input value.
4557           SDValue Val;
4558           if (N0.getOpcode() == ISD::XOR) {
4559             Val = N0.getOperand(0);
4560           } else {
4561             assert(N0.getOpcode() == ISD::AND &&
4562                     N0.getOperand(0).getOpcode() == ISD::XOR);
4563             // ((X^1)&1)^1 -> X & 1
4564             Val = DAG.getNode(ISD::AND, dl, N0.getValueType(),
4565                               N0.getOperand(0).getOperand(0),
4566                               N0.getOperand(1));
4567           }
4568 
4569           return DAG.getSetCC(dl, VT, Val, N1,
4570                               Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
4571         }
4572       } else if (N1C->isOne()) {
4573         SDValue Op0 = N0;
4574         if (Op0.getOpcode() == ISD::TRUNCATE)
4575           Op0 = Op0.getOperand(0);
4576 
4577         if ((Op0.getOpcode() == ISD::XOR) &&
4578             Op0.getOperand(0).getOpcode() == ISD::SETCC &&
4579             Op0.getOperand(1).getOpcode() == ISD::SETCC) {
4580           SDValue XorLHS = Op0.getOperand(0);
4581           SDValue XorRHS = Op0.getOperand(1);
4582           // Ensure that the input setccs return an i1 type or 0/1 value.
4583           if (Op0.getValueType() == MVT::i1 ||
4584               (getBooleanContents(XorLHS.getOperand(0).getValueType()) ==
4585                       ZeroOrOneBooleanContent &&
4586                getBooleanContents(XorRHS.getOperand(0).getValueType()) ==
4587                         ZeroOrOneBooleanContent)) {
4588             // (xor (setcc), (setcc)) == / != 1 -> (setcc) != / == (setcc)
4589             Cond = (Cond == ISD::SETEQ) ? ISD::SETNE : ISD::SETEQ;
4590             return DAG.getSetCC(dl, VT, XorLHS, XorRHS, Cond);
4591           }
4592         }
4593         if (Op0.getOpcode() == ISD::AND && isOneConstant(Op0.getOperand(1))) {
4594           // If this is (X&1) == / != 1, normalize it to (X&1) != / == 0.
4595           if (Op0.getValueType().bitsGT(VT))
4596             Op0 = DAG.getNode(ISD::AND, dl, VT,
4597                           DAG.getNode(ISD::TRUNCATE, dl, VT, Op0.getOperand(0)),
4598                           DAG.getConstant(1, dl, VT));
4599           else if (Op0.getValueType().bitsLT(VT))
4600             Op0 = DAG.getNode(ISD::AND, dl, VT,
4601                         DAG.getNode(ISD::ANY_EXTEND, dl, VT, Op0.getOperand(0)),
4602                         DAG.getConstant(1, dl, VT));
4603 
4604           return DAG.getSetCC(dl, VT, Op0,
4605                               DAG.getConstant(0, dl, Op0.getValueType()),
4606                               Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
4607         }
4608         if (Op0.getOpcode() == ISD::AssertZext &&
4609             cast<VTSDNode>(Op0.getOperand(1))->getVT() == MVT::i1)
4610           return DAG.getSetCC(dl, VT, Op0,
4611                               DAG.getConstant(0, dl, Op0.getValueType()),
4612                               Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
4613       }
4614     }
4615 
4616     // Given:
4617     //   icmp eq/ne (urem %x, %y), 0
4618     // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
4619     //   icmp eq/ne %x, 0
4620     if (N0.getOpcode() == ISD::UREM && N1C->isZero() &&
4621         (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
4622       KnownBits XKnown = DAG.computeKnownBits(N0.getOperand(0));
4623       KnownBits YKnown = DAG.computeKnownBits(N0.getOperand(1));
4624       if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2)
4625         return DAG.getSetCC(dl, VT, N0.getOperand(0), N1, Cond);
4626     }
4627 
4628     // Fold set_cc seteq (ashr X, BW-1), -1 -> set_cc setlt X, 0
4629     //  and set_cc setne (ashr X, BW-1), -1 -> set_cc setge X, 0
4630     if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4631         N0.getOpcode() == ISD::SRA && isa<ConstantSDNode>(N0.getOperand(1)) &&
4632         N0.getConstantOperandAPInt(1) == OpVT.getScalarSizeInBits() - 1 &&
4633         N1C && N1C->isAllOnes()) {
4634       return DAG.getSetCC(dl, VT, N0.getOperand(0),
4635                           DAG.getConstant(0, dl, OpVT),
4636                           Cond == ISD::SETEQ ? ISD::SETLT : ISD::SETGE);
4637     }
4638 
4639     if (SDValue V =
4640             optimizeSetCCOfSignedTruncationCheck(VT, N0, N1, Cond, DCI, dl))
4641       return V;
4642   }
4643 
4644   // These simplifications apply to splat vectors as well.
4645   // TODO: Handle more splat vector cases.
4646   if (auto *N1C = isConstOrConstSplat(N1)) {
4647     const APInt &C1 = N1C->getAPIntValue();
4648 
4649     APInt MinVal, MaxVal;
4650     unsigned OperandBitSize = N1C->getValueType(0).getScalarSizeInBits();
4651     if (ISD::isSignedIntSetCC(Cond)) {
4652       MinVal = APInt::getSignedMinValue(OperandBitSize);
4653       MaxVal = APInt::getSignedMaxValue(OperandBitSize);
4654     } else {
4655       MinVal = APInt::getMinValue(OperandBitSize);
4656       MaxVal = APInt::getMaxValue(OperandBitSize);
4657     }
4658 
4659     // Canonicalize GE/LE comparisons to use GT/LT comparisons.
4660     if (Cond == ISD::SETGE || Cond == ISD::SETUGE) {
4661       // X >= MIN --> true
4662       if (C1 == MinVal)
4663         return DAG.getBoolConstant(true, dl, VT, OpVT);
4664 
4665       if (!VT.isVector()) { // TODO: Support this for vectors.
4666         // X >= C0 --> X > (C0 - 1)
4667         APInt C = C1 - 1;
4668         ISD::CondCode NewCC = (Cond == ISD::SETGE) ? ISD::SETGT : ISD::SETUGT;
4669         if ((DCI.isBeforeLegalizeOps() ||
4670              isCondCodeLegal(NewCC, VT.getSimpleVT())) &&
4671             (!N1C->isOpaque() || (C.getBitWidth() <= 64 &&
4672                                   isLegalICmpImmediate(C.getSExtValue())))) {
4673           return DAG.getSetCC(dl, VT, N0,
4674                               DAG.getConstant(C, dl, N1.getValueType()),
4675                               NewCC);
4676         }
4677       }
4678     }
4679 
4680     if (Cond == ISD::SETLE || Cond == ISD::SETULE) {
4681       // X <= MAX --> true
4682       if (C1 == MaxVal)
4683         return DAG.getBoolConstant(true, dl, VT, OpVT);
4684 
4685       // X <= C0 --> X < (C0 + 1)
4686       if (!VT.isVector()) { // TODO: Support this for vectors.
4687         APInt C = C1 + 1;
4688         ISD::CondCode NewCC = (Cond == ISD::SETLE) ? ISD::SETLT : ISD::SETULT;
4689         if ((DCI.isBeforeLegalizeOps() ||
4690              isCondCodeLegal(NewCC, VT.getSimpleVT())) &&
4691             (!N1C->isOpaque() || (C.getBitWidth() <= 64 &&
4692                                   isLegalICmpImmediate(C.getSExtValue())))) {
4693           return DAG.getSetCC(dl, VT, N0,
4694                               DAG.getConstant(C, dl, N1.getValueType()),
4695                               NewCC);
4696         }
4697       }
4698     }
4699 
4700     if (Cond == ISD::SETLT || Cond == ISD::SETULT) {
4701       if (C1 == MinVal)
4702         return DAG.getBoolConstant(false, dl, VT, OpVT); // X < MIN --> false
4703 
4704       // TODO: Support this for vectors after legalize ops.
4705       if (!VT.isVector() || DCI.isBeforeLegalizeOps()) {
4706         // Canonicalize setlt X, Max --> setne X, Max
4707         if (C1 == MaxVal)
4708           return DAG.getSetCC(dl, VT, N0, N1, ISD::SETNE);
4709 
4710         // If we have setult X, 1, turn it into seteq X, 0
4711         if (C1 == MinVal+1)
4712           return DAG.getSetCC(dl, VT, N0,
4713                               DAG.getConstant(MinVal, dl, N0.getValueType()),
4714                               ISD::SETEQ);
4715       }
4716     }
4717 
4718     if (Cond == ISD::SETGT || Cond == ISD::SETUGT) {
4719       if (C1 == MaxVal)
4720         return DAG.getBoolConstant(false, dl, VT, OpVT); // X > MAX --> false
4721 
4722       // TODO: Support this for vectors after legalize ops.
4723       if (!VT.isVector() || DCI.isBeforeLegalizeOps()) {
4724         // Canonicalize setgt X, Min --> setne X, Min
4725         if (C1 == MinVal)
4726           return DAG.getSetCC(dl, VT, N0, N1, ISD::SETNE);
4727 
4728         // If we have setugt X, Max-1, turn it into seteq X, Max
4729         if (C1 == MaxVal-1)
4730           return DAG.getSetCC(dl, VT, N0,
4731                               DAG.getConstant(MaxVal, dl, N0.getValueType()),
4732                               ISD::SETEQ);
4733       }
4734     }
4735 
4736     if (Cond == ISD::SETEQ || Cond == ISD::SETNE) {
4737       // (X & (C l>>/<< Y)) ==/!= 0  -->  ((X <</l>> Y) & C) ==/!= 0
4738       if (C1.isZero())
4739         if (SDValue CC = optimizeSetCCByHoistingAndByConstFromLogicalShift(
4740                 VT, N0, N1, Cond, DCI, dl))
4741           return CC;
4742 
4743       // For all/any comparisons, replace or(x,shl(y,bw/2)) with and/or(x,y).
4744       // For example, when high 32-bits of i64 X are known clear:
4745       // all bits clear: (X | (Y<<32)) ==  0 --> (X | Y) ==  0
4746       // all bits set:   (X | (Y<<32)) == -1 --> (X & Y) == -1
4747       bool CmpZero = N1C->getAPIntValue().isZero();
4748       bool CmpNegOne = N1C->getAPIntValue().isAllOnes();
4749       if ((CmpZero || CmpNegOne) && N0.hasOneUse()) {
4750         // Match or(lo,shl(hi,bw/2)) pattern.
4751         auto IsConcat = [&](SDValue V, SDValue &Lo, SDValue &Hi) {
4752           unsigned EltBits = V.getScalarValueSizeInBits();
4753           if (V.getOpcode() != ISD::OR || (EltBits % 2) != 0)
4754             return false;
4755           SDValue LHS = V.getOperand(0);
4756           SDValue RHS = V.getOperand(1);
4757           APInt HiBits = APInt::getHighBitsSet(EltBits, EltBits / 2);
4758           // Unshifted element must have zero upperbits.
4759           if (RHS.getOpcode() == ISD::SHL &&
4760               isa<ConstantSDNode>(RHS.getOperand(1)) &&
4761               RHS.getConstantOperandAPInt(1) == (EltBits / 2) &&
4762               DAG.MaskedValueIsZero(LHS, HiBits)) {
4763             Lo = LHS;
4764             Hi = RHS.getOperand(0);
4765             return true;
4766           }
4767           if (LHS.getOpcode() == ISD::SHL &&
4768               isa<ConstantSDNode>(LHS.getOperand(1)) &&
4769               LHS.getConstantOperandAPInt(1) == (EltBits / 2) &&
4770               DAG.MaskedValueIsZero(RHS, HiBits)) {
4771             Lo = RHS;
4772             Hi = LHS.getOperand(0);
4773             return true;
4774           }
4775           return false;
4776         };
4777 
4778         auto MergeConcat = [&](SDValue Lo, SDValue Hi) {
4779           unsigned EltBits = N0.getScalarValueSizeInBits();
4780           unsigned HalfBits = EltBits / 2;
4781           APInt HiBits = APInt::getHighBitsSet(EltBits, HalfBits);
4782           SDValue LoBits = DAG.getConstant(~HiBits, dl, OpVT);
4783           SDValue HiMask = DAG.getNode(ISD::AND, dl, OpVT, Hi, LoBits);
4784           SDValue NewN0 =
4785               DAG.getNode(CmpZero ? ISD::OR : ISD::AND, dl, OpVT, Lo, HiMask);
4786           SDValue NewN1 = CmpZero ? DAG.getConstant(0, dl, OpVT) : LoBits;
4787           return DAG.getSetCC(dl, VT, NewN0, NewN1, Cond);
4788         };
4789 
4790         SDValue Lo, Hi;
4791         if (IsConcat(N0, Lo, Hi))
4792           return MergeConcat(Lo, Hi);
4793 
4794         if (N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR) {
4795           SDValue Lo0, Lo1, Hi0, Hi1;
4796           if (IsConcat(N0.getOperand(0), Lo0, Hi0) &&
4797               IsConcat(N0.getOperand(1), Lo1, Hi1)) {
4798             return MergeConcat(DAG.getNode(N0.getOpcode(), dl, OpVT, Lo0, Lo1),
4799                                DAG.getNode(N0.getOpcode(), dl, OpVT, Hi0, Hi1));
4800           }
4801         }
4802       }
4803     }
4804 
4805     // If we have "setcc X, C0", check to see if we can shrink the immediate
4806     // by changing cc.
4807     // TODO: Support this for vectors after legalize ops.
4808     if (!VT.isVector() || DCI.isBeforeLegalizeOps()) {
4809       // SETUGT X, SINTMAX  -> SETLT X, 0
4810       // SETUGE X, SINTMIN -> SETLT X, 0
4811       if ((Cond == ISD::SETUGT && C1.isMaxSignedValue()) ||
4812           (Cond == ISD::SETUGE && C1.isMinSignedValue()))
4813         return DAG.getSetCC(dl, VT, N0,
4814                             DAG.getConstant(0, dl, N1.getValueType()),
4815                             ISD::SETLT);
4816 
4817       // SETULT X, SINTMIN  -> SETGT X, -1
4818       // SETULE X, SINTMAX  -> SETGT X, -1
4819       if ((Cond == ISD::SETULT && C1.isMinSignedValue()) ||
4820           (Cond == ISD::SETULE && C1.isMaxSignedValue()))
4821         return DAG.getSetCC(dl, VT, N0,
4822                             DAG.getAllOnesConstant(dl, N1.getValueType()),
4823                             ISD::SETGT);
4824     }
4825   }
4826 
4827   // Back to non-vector simplifications.
4828   // TODO: Can we do these for vector splats?
4829   if (auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) {
4830     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
4831     const APInt &C1 = N1C->getAPIntValue();
4832     EVT ShValTy = N0.getValueType();
4833 
4834     // Fold bit comparisons when we can. This will result in an
4835     // incorrect value when boolean false is negative one, unless
4836     // the bitsize is 1 in which case the false value is the same
4837     // in practice regardless of the representation.
4838     if ((VT.getSizeInBits() == 1 ||
4839          getBooleanContents(N0.getValueType()) == ZeroOrOneBooleanContent) &&
4840         (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4841         (VT == ShValTy || (isTypeLegal(VT) && VT.bitsLE(ShValTy))) &&
4842         N0.getOpcode() == ISD::AND) {
4843       if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4844         EVT ShiftTy =
4845             getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
4846         if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0  -->  (X & 8) >> 3
4847           // Perform the xform if the AND RHS is a single bit.
4848           unsigned ShCt = AndRHS->getAPIntValue().logBase2();
4849           if (AndRHS->getAPIntValue().isPowerOf2() &&
4850               !TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
4851             return DAG.getNode(ISD::TRUNCATE, dl, VT,
4852                                DAG.getNode(ISD::SRL, dl, ShValTy, N0,
4853                                            DAG.getConstant(ShCt, dl, ShiftTy)));
4854           }
4855         } else if (Cond == ISD::SETEQ && C1 == AndRHS->getAPIntValue()) {
4856           // (X & 8) == 8  -->  (X & 8) >> 3
4857           // Perform the xform if C1 is a single bit.
4858           unsigned ShCt = C1.logBase2();
4859           if (C1.isPowerOf2() &&
4860               !TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
4861             return DAG.getNode(ISD::TRUNCATE, dl, VT,
4862                                DAG.getNode(ISD::SRL, dl, ShValTy, N0,
4863                                            DAG.getConstant(ShCt, dl, ShiftTy)));
4864           }
4865         }
4866       }
4867     }
4868 
4869     if (C1.getMinSignedBits() <= 64 &&
4870         !isLegalICmpImmediate(C1.getSExtValue())) {
4871       EVT ShiftTy = getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
4872       // (X & -256) == 256 -> (X >> 8) == 1
4873       if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4874           N0.getOpcode() == ISD::AND && N0.hasOneUse()) {
4875         if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
4876           const APInt &AndRHSC = AndRHS->getAPIntValue();
4877           if (AndRHSC.isNegatedPowerOf2() && (AndRHSC & C1) == C1) {
4878             unsigned ShiftBits = AndRHSC.countTrailingZeros();
4879             if (!TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
4880               SDValue Shift =
4881                 DAG.getNode(ISD::SRL, dl, ShValTy, N0.getOperand(0),
4882                             DAG.getConstant(ShiftBits, dl, ShiftTy));
4883               SDValue CmpRHS = DAG.getConstant(C1.lshr(ShiftBits), dl, ShValTy);
4884               return DAG.getSetCC(dl, VT, Shift, CmpRHS, Cond);
4885             }
4886           }
4887         }
4888       } else if (Cond == ISD::SETULT || Cond == ISD::SETUGE ||
4889                  Cond == ISD::SETULE || Cond == ISD::SETUGT) {
4890         bool AdjOne = (Cond == ISD::SETULE || Cond == ISD::SETUGT);
4891         // X <  0x100000000 -> (X >> 32) <  1
4892         // X >= 0x100000000 -> (X >> 32) >= 1
4893         // X <= 0x0ffffffff -> (X >> 32) <  1
4894         // X >  0x0ffffffff -> (X >> 32) >= 1
4895         unsigned ShiftBits;
4896         APInt NewC = C1;
4897         ISD::CondCode NewCond = Cond;
4898         if (AdjOne) {
4899           ShiftBits = C1.countTrailingOnes();
4900           NewC = NewC + 1;
4901           NewCond = (Cond == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE;
4902         } else {
4903           ShiftBits = C1.countTrailingZeros();
4904         }
4905         NewC.lshrInPlace(ShiftBits);
4906         if (ShiftBits && NewC.getMinSignedBits() <= 64 &&
4907             isLegalICmpImmediate(NewC.getSExtValue()) &&
4908             !TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
4909           SDValue Shift = DAG.getNode(ISD::SRL, dl, ShValTy, N0,
4910                                       DAG.getConstant(ShiftBits, dl, ShiftTy));
4911           SDValue CmpRHS = DAG.getConstant(NewC, dl, ShValTy);
4912           return DAG.getSetCC(dl, VT, Shift, CmpRHS, NewCond);
4913         }
4914       }
4915     }
4916   }
4917 
4918   if (!isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1)) {
4919     auto *CFP = cast<ConstantFPSDNode>(N1);
4920     assert(!CFP->getValueAPF().isNaN() && "Unexpected NaN value");
4921 
4922     // Otherwise, we know the RHS is not a NaN.  Simplify the node to drop the
4923     // constant if knowing that the operand is non-nan is enough.  We prefer to
4924     // have SETO(x,x) instead of SETO(x, 0.0) because this avoids having to
4925     // materialize 0.0.
4926     if (Cond == ISD::SETO || Cond == ISD::SETUO)
4927       return DAG.getSetCC(dl, VT, N0, N0, Cond);
4928 
4929     // setcc (fneg x), C -> setcc swap(pred) x, -C
4930     if (N0.getOpcode() == ISD::FNEG) {
4931       ISD::CondCode SwapCond = ISD::getSetCCSwappedOperands(Cond);
4932       if (DCI.isBeforeLegalizeOps() ||
4933           isCondCodeLegal(SwapCond, N0.getSimpleValueType())) {
4934         SDValue NegN1 = DAG.getNode(ISD::FNEG, dl, N0.getValueType(), N1);
4935         return DAG.getSetCC(dl, VT, N0.getOperand(0), NegN1, SwapCond);
4936       }
4937     }
4938 
4939     // If the condition is not legal, see if we can find an equivalent one
4940     // which is legal.
4941     if (!isCondCodeLegal(Cond, N0.getSimpleValueType())) {
4942       // If the comparison was an awkward floating-point == or != and one of
4943       // the comparison operands is infinity or negative infinity, convert the
4944       // condition to a less-awkward <= or >=.
4945       if (CFP->getValueAPF().isInfinity()) {
4946         bool IsNegInf = CFP->getValueAPF().isNegative();
4947         ISD::CondCode NewCond = ISD::SETCC_INVALID;
4948         switch (Cond) {
4949         case ISD::SETOEQ: NewCond = IsNegInf ? ISD::SETOLE : ISD::SETOGE; break;
4950         case ISD::SETUEQ: NewCond = IsNegInf ? ISD::SETULE : ISD::SETUGE; break;
4951         case ISD::SETUNE: NewCond = IsNegInf ? ISD::SETUGT : ISD::SETULT; break;
4952         case ISD::SETONE: NewCond = IsNegInf ? ISD::SETOGT : ISD::SETOLT; break;
4953         default: break;
4954         }
4955         if (NewCond != ISD::SETCC_INVALID &&
4956             isCondCodeLegal(NewCond, N0.getSimpleValueType()))
4957           return DAG.getSetCC(dl, VT, N0, N1, NewCond);
4958       }
4959     }
4960   }
4961 
4962   if (N0 == N1) {
4963     // The sext(setcc()) => setcc() optimization relies on the appropriate
4964     // constant being emitted.
4965     assert(!N0.getValueType().isInteger() &&
4966            "Integer types should be handled by FoldSetCC");
4967 
4968     bool EqTrue = ISD::isTrueWhenEqual(Cond);
4969     unsigned UOF = ISD::getUnorderedFlavor(Cond);
4970     if (UOF == 2) // FP operators that are undefined on NaNs.
4971       return DAG.getBoolConstant(EqTrue, dl, VT, OpVT);
4972     if (UOF == unsigned(EqTrue))
4973       return DAG.getBoolConstant(EqTrue, dl, VT, OpVT);
4974     // Otherwise, we can't fold it.  However, we can simplify it to SETUO/SETO
4975     // if it is not already.
4976     ISD::CondCode NewCond = UOF == 0 ? ISD::SETO : ISD::SETUO;
4977     if (NewCond != Cond &&
4978         (DCI.isBeforeLegalizeOps() ||
4979                             isCondCodeLegal(NewCond, N0.getSimpleValueType())))
4980       return DAG.getSetCC(dl, VT, N0, N1, NewCond);
4981   }
4982 
4983   if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4984       N0.getValueType().isInteger()) {
4985     if (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB ||
4986         N0.getOpcode() == ISD::XOR) {
4987       // Simplify (X+Y) == (X+Z) -->  Y == Z
4988       if (N0.getOpcode() == N1.getOpcode()) {
4989         if (N0.getOperand(0) == N1.getOperand(0))
4990           return DAG.getSetCC(dl, VT, N0.getOperand(1), N1.getOperand(1), Cond);
4991         if (N0.getOperand(1) == N1.getOperand(1))
4992           return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
4993         if (isCommutativeBinOp(N0.getOpcode())) {
4994           // If X op Y == Y op X, try other combinations.
4995           if (N0.getOperand(0) == N1.getOperand(1))
4996             return DAG.getSetCC(dl, VT, N0.getOperand(1), N1.getOperand(0),
4997                                 Cond);
4998           if (N0.getOperand(1) == N1.getOperand(0))
4999             return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(1),
5000                                 Cond);
5001         }
5002       }
5003 
5004       // If RHS is a legal immediate value for a compare instruction, we need
5005       // to be careful about increasing register pressure needlessly.
5006       bool LegalRHSImm = false;
5007 
5008       if (auto *RHSC = dyn_cast<ConstantSDNode>(N1)) {
5009         if (auto *LHSR = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5010           // Turn (X+C1) == C2 --> X == C2-C1
5011           if (N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse())
5012             return DAG.getSetCC(
5013                 dl, VT, N0.getOperand(0),
5014                 DAG.getConstant(RHSC->getAPIntValue() - LHSR->getAPIntValue(),
5015                                 dl, N0.getValueType()),
5016                 Cond);
5017 
5018           // Turn (X^C1) == C2 --> X == C1^C2
5019           if (N0.getOpcode() == ISD::XOR && N0.getNode()->hasOneUse())
5020             return DAG.getSetCC(
5021                 dl, VT, N0.getOperand(0),
5022                 DAG.getConstant(LHSR->getAPIntValue() ^ RHSC->getAPIntValue(),
5023                                 dl, N0.getValueType()),
5024                 Cond);
5025         }
5026 
5027         // Turn (C1-X) == C2 --> X == C1-C2
5028         if (auto *SUBC = dyn_cast<ConstantSDNode>(N0.getOperand(0)))
5029           if (N0.getOpcode() == ISD::SUB && N0.getNode()->hasOneUse())
5030             return DAG.getSetCC(
5031                 dl, VT, N0.getOperand(1),
5032                 DAG.getConstant(SUBC->getAPIntValue() - RHSC->getAPIntValue(),
5033                                 dl, N0.getValueType()),
5034                 Cond);
5035 
5036         // Could RHSC fold directly into a compare?
5037         if (RHSC->getValueType(0).getSizeInBits() <= 64)
5038           LegalRHSImm = isLegalICmpImmediate(RHSC->getSExtValue());
5039       }
5040 
5041       // (X+Y) == X --> Y == 0 and similar folds.
5042       // Don't do this if X is an immediate that can fold into a cmp
5043       // instruction and X+Y has other uses. It could be an induction variable
5044       // chain, and the transform would increase register pressure.
5045       if (!LegalRHSImm || N0.hasOneUse())
5046         if (SDValue V = foldSetCCWithBinOp(VT, N0, N1, Cond, dl, DCI))
5047           return V;
5048     }
5049 
5050     if (N1.getOpcode() == ISD::ADD || N1.getOpcode() == ISD::SUB ||
5051         N1.getOpcode() == ISD::XOR)
5052       if (SDValue V = foldSetCCWithBinOp(VT, N1, N0, Cond, dl, DCI))
5053         return V;
5054 
5055     if (SDValue V = foldSetCCWithAnd(VT, N0, N1, Cond, dl, DCI))
5056       return V;
5057   }
5058 
5059   // Fold remainder of division by a constant.
5060   if ((N0.getOpcode() == ISD::UREM || N0.getOpcode() == ISD::SREM) &&
5061       N0.hasOneUse() && (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
5062     // When division is cheap or optimizing for minimum size,
5063     // fall through to DIVREM creation by skipping this fold.
5064     if (!isIntDivCheap(VT, Attr) && !Attr.hasFnAttr(Attribute::MinSize)) {
5065       if (N0.getOpcode() == ISD::UREM) {
5066         if (SDValue Folded = buildUREMEqFold(VT, N0, N1, Cond, DCI, dl))
5067           return Folded;
5068       } else if (N0.getOpcode() == ISD::SREM) {
5069         if (SDValue Folded = buildSREMEqFold(VT, N0, N1, Cond, DCI, dl))
5070           return Folded;
5071       }
5072     }
5073   }
5074 
5075   // Fold away ALL boolean setcc's.
5076   if (N0.getValueType().getScalarType() == MVT::i1 && foldBooleans) {
5077     SDValue Temp;
5078     switch (Cond) {
5079     default: llvm_unreachable("Unknown integer setcc!");
5080     case ISD::SETEQ:  // X == Y  -> ~(X^Y)
5081       Temp = DAG.getNode(ISD::XOR, dl, OpVT, N0, N1);
5082       N0 = DAG.getNOT(dl, Temp, OpVT);
5083       if (!DCI.isCalledByLegalizer())
5084         DCI.AddToWorklist(Temp.getNode());
5085       break;
5086     case ISD::SETNE:  // X != Y   -->  (X^Y)
5087       N0 = DAG.getNode(ISD::XOR, dl, OpVT, N0, N1);
5088       break;
5089     case ISD::SETGT:  // X >s Y   -->  X == 0 & Y == 1  -->  ~X & Y
5090     case ISD::SETULT: // X <u Y   -->  X == 0 & Y == 1  -->  ~X & Y
5091       Temp = DAG.getNOT(dl, N0, OpVT);
5092       N0 = DAG.getNode(ISD::AND, dl, OpVT, N1, Temp);
5093       if (!DCI.isCalledByLegalizer())
5094         DCI.AddToWorklist(Temp.getNode());
5095       break;
5096     case ISD::SETLT:  // X <s Y   --> X == 1 & Y == 0  -->  ~Y & X
5097     case ISD::SETUGT: // X >u Y   --> X == 1 & Y == 0  -->  ~Y & X
5098       Temp = DAG.getNOT(dl, N1, OpVT);
5099       N0 = DAG.getNode(ISD::AND, dl, OpVT, N0, Temp);
5100       if (!DCI.isCalledByLegalizer())
5101         DCI.AddToWorklist(Temp.getNode());
5102       break;
5103     case ISD::SETULE: // X <=u Y  --> X == 0 | Y == 1  -->  ~X | Y
5104     case ISD::SETGE:  // X >=s Y  --> X == 0 | Y == 1  -->  ~X | Y
5105       Temp = DAG.getNOT(dl, N0, OpVT);
5106       N0 = DAG.getNode(ISD::OR, dl, OpVT, N1, Temp);
5107       if (!DCI.isCalledByLegalizer())
5108         DCI.AddToWorklist(Temp.getNode());
5109       break;
5110     case ISD::SETUGE: // X >=u Y  --> X == 1 | Y == 0  -->  ~Y | X
5111     case ISD::SETLE:  // X <=s Y  --> X == 1 | Y == 0  -->  ~Y | X
5112       Temp = DAG.getNOT(dl, N1, OpVT);
5113       N0 = DAG.getNode(ISD::OR, dl, OpVT, N0, Temp);
5114       break;
5115     }
5116     if (VT.getScalarType() != MVT::i1) {
5117       if (!DCI.isCalledByLegalizer())
5118         DCI.AddToWorklist(N0.getNode());
5119       // FIXME: If running after legalize, we probably can't do this.
5120       ISD::NodeType ExtendCode = getExtendForContent(getBooleanContents(OpVT));
5121       N0 = DAG.getNode(ExtendCode, dl, VT, N0);
5122     }
5123     return N0;
5124   }
5125 
5126   // Could not fold it.
5127   return SDValue();
5128 }
5129 
5130 /// Returns true (and the GlobalValue and the offset) if the node is a
5131 /// GlobalAddress + offset.
isGAPlusOffset(SDNode * WN,const GlobalValue * & GA,int64_t & Offset) const5132 bool TargetLowering::isGAPlusOffset(SDNode *WN, const GlobalValue *&GA,
5133                                     int64_t &Offset) const {
5134 
5135   SDNode *N = unwrapAddress(SDValue(WN, 0)).getNode();
5136 
5137   if (auto *GASD = dyn_cast<GlobalAddressSDNode>(N)) {
5138     GA = GASD->getGlobal();
5139     Offset += GASD->getOffset();
5140     return true;
5141   }
5142 
5143   if (N->getOpcode() == ISD::ADD) {
5144     SDValue N1 = N->getOperand(0);
5145     SDValue N2 = N->getOperand(1);
5146     if (isGAPlusOffset(N1.getNode(), GA, Offset)) {
5147       if (auto *V = dyn_cast<ConstantSDNode>(N2)) {
5148         Offset += V->getSExtValue();
5149         return true;
5150       }
5151     } else if (isGAPlusOffset(N2.getNode(), GA, Offset)) {
5152       if (auto *V = dyn_cast<ConstantSDNode>(N1)) {
5153         Offset += V->getSExtValue();
5154         return true;
5155       }
5156     }
5157   }
5158 
5159   return false;
5160 }
5161 
PerformDAGCombine(SDNode * N,DAGCombinerInfo & DCI) const5162 SDValue TargetLowering::PerformDAGCombine(SDNode *N,
5163                                           DAGCombinerInfo &DCI) const {
5164   // Default implementation: no optimization.
5165   return SDValue();
5166 }
5167 
5168 //===----------------------------------------------------------------------===//
5169 //  Inline Assembler Implementation Methods
5170 //===----------------------------------------------------------------------===//
5171 
5172 TargetLowering::ConstraintType
getConstraintType(StringRef Constraint) const5173 TargetLowering::getConstraintType(StringRef Constraint) const {
5174   unsigned S = Constraint.size();
5175 
5176   if (S == 1) {
5177     switch (Constraint[0]) {
5178     default: break;
5179     case 'r':
5180       return C_RegisterClass;
5181     case 'm': // memory
5182     case 'o': // offsetable
5183     case 'V': // not offsetable
5184       return C_Memory;
5185     case 'p': // Address.
5186       return C_Address;
5187     case 'n': // Simple Integer
5188     case 'E': // Floating Point Constant
5189     case 'F': // Floating Point Constant
5190       return C_Immediate;
5191     case 'i': // Simple Integer or Relocatable Constant
5192     case 's': // Relocatable Constant
5193     case 'X': // Allow ANY value.
5194     case 'I': // Target registers.
5195     case 'J':
5196     case 'K':
5197     case 'L':
5198     case 'M':
5199     case 'N':
5200     case 'O':
5201     case 'P':
5202     case '<':
5203     case '>':
5204       return C_Other;
5205     }
5206   }
5207 
5208   if (S > 1 && Constraint[0] == '{' && Constraint[S - 1] == '}') {
5209     if (S == 8 && Constraint.substr(1, 6) == "memory") // "{memory}"
5210       return C_Memory;
5211     return C_Register;
5212   }
5213   return C_Unknown;
5214 }
5215 
5216 /// Try to replace an X constraint, which matches anything, with another that
5217 /// has more specific requirements based on the type of the corresponding
5218 /// operand.
LowerXConstraint(EVT ConstraintVT) const5219 const char *TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
5220   if (ConstraintVT.isInteger())
5221     return "r";
5222   if (ConstraintVT.isFloatingPoint())
5223     return "f"; // works for many targets
5224   return nullptr;
5225 }
5226 
LowerAsmOutputForConstraint(SDValue & Chain,SDValue & Flag,const SDLoc & DL,const AsmOperandInfo & OpInfo,SelectionDAG & DAG) const5227 SDValue TargetLowering::LowerAsmOutputForConstraint(
5228     SDValue &Chain, SDValue &Flag, const SDLoc &DL,
5229     const AsmOperandInfo &OpInfo, SelectionDAG &DAG) const {
5230   return SDValue();
5231 }
5232 
5233 /// Lower the specified operand into the Ops vector.
5234 /// If it is invalid, don't add anything to Ops.
void TargetLowering::LowerAsmOperandForConstraint(SDValue Op,
                                                  std::string &Constraint,
                                                  std::vector<SDValue> &Ops,
                                                  SelectionDAG &DAG) const {

  // Only single-letter constraints are handled here.
  if (Constraint.length() > 1) return;

  char ConstraintLetter = Constraint[0];
  switch (ConstraintLetter) {
  default: break;
  case 'X':    // Allows any operand
  case 'i':    // Simple Integer or Relocatable Constant
  case 'n':    // Simple Integer
  case 's': {  // Relocatable Constant

    ConstantSDNode *C;
    // Constant displacement accumulated while peeling ADD/SUB chains below.
    uint64_t Offset = 0;

    // Match (GA) or (C) or (GA+C) or (GA-C) or ((GA+C)+C) or (((GA+C)+C)+C),
    // etc., since getelementpointer is variadic. We can't use
    // SelectionDAG::FoldSymbolOffset because it expects the GA to be accessible
    // while in this case the GA may be furthest from the root node which is
    // likely an ISD::ADD.
    while (true) {
      // A bare constant satisfies 'X'/'i'/'n' (but not 's', which requires
      // something relocatable).
      if ((C = dyn_cast<ConstantSDNode>(Op)) && ConstraintLetter != 's') {
        // gcc prints these as sign extended.  Sign extend value to 64 bits
        // now; without this it would get ZExt'd later in
        // ScheduleDAGSDNodes::EmitNode, which is very generic.
        bool IsBool = C->getConstantIntValue()->getBitWidth() == 1;
        BooleanContent BCont = getBooleanContents(MVT::i64);
        ISD::NodeType ExtOpc =
            IsBool ? getExtendForContent(BCont) : ISD::SIGN_EXTEND;
        int64_t ExtVal =
            ExtOpc == ISD::ZERO_EXTEND ? C->getZExtValue() : C->getSExtValue();
        Ops.push_back(
            DAG.getTargetConstant(Offset + ExtVal, SDLoc(C), MVT::i64));
        return;
      }
      // Symbolic operands satisfy every letter here except 'n', which
      // demands a plain integer.
      if (ConstraintLetter != 'n') {
        if (const auto *GA = dyn_cast<GlobalAddressSDNode>(Op)) {
          Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op),
                                                   GA->getValueType(0),
                                                   Offset + GA->getOffset()));
          return;
        }
        if (const auto *BA = dyn_cast<BlockAddressSDNode>(Op)) {
          Ops.push_back(DAG.getTargetBlockAddress(
              BA->getBlockAddress(), BA->getValueType(0),
              Offset + BA->getOffset(), BA->getTargetFlags()));
          return;
        }
        if (isa<BasicBlockSDNode>(Op)) {
          Ops.push_back(Op);
          return;
        }
      }
      // Peel one constant operand off an ADD/SUB, fold it into Offset, and
      // keep walking toward the symbol at the bottom of the chain.
      const unsigned OpCode = Op.getOpcode();
      if (OpCode == ISD::ADD || OpCode == ISD::SUB) {
        if ((C = dyn_cast<ConstantSDNode>(Op.getOperand(0))))
          Op = Op.getOperand(1);
        // Subtraction is not commutative.
        else if (OpCode == ISD::ADD &&
                 (C = dyn_cast<ConstantSDNode>(Op.getOperand(1))))
          Op = Op.getOperand(0);
        else
          return;
        Offset += (OpCode == ISD::ADD ? 1 : -1) * C->getSExtValue();
        continue;
      }
      // Unrecognized shape: add nothing, marking the operand invalid.
      return;
    }
    break;
  }
  }
}
5310 
CollectTargetIntrinsicOperands(const CallInst & I,SmallVectorImpl<SDValue> & Ops,SelectionDAG & DAG) const5311 void TargetLowering::CollectTargetIntrinsicOperands(const CallInst &I,
5312                                            SmallVectorImpl<SDValue> &Ops,
5313                                            SelectionDAG &DAG) const {
5314   return;
5315 }
5316 
5317 std::pair<unsigned, const TargetRegisterClass *>
getRegForInlineAsmConstraint(const TargetRegisterInfo * RI,StringRef Constraint,MVT VT) const5318 TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI,
5319                                              StringRef Constraint,
5320                                              MVT VT) const {
5321   if (Constraint.empty() || Constraint[0] != '{')
5322     return std::make_pair(0u, static_cast<TargetRegisterClass *>(nullptr));
5323   assert(*(Constraint.end() - 1) == '}' && "Not a brace enclosed constraint?");
5324 
5325   // Remove the braces from around the name.
5326   StringRef RegName(Constraint.data() + 1, Constraint.size() - 2);
5327 
5328   std::pair<unsigned, const TargetRegisterClass *> R =
5329       std::make_pair(0u, static_cast<const TargetRegisterClass *>(nullptr));
5330 
5331   // Figure out which register class contains this reg.
5332   for (const TargetRegisterClass *RC : RI->regclasses()) {
5333     // If none of the value types for this register class are valid, we
5334     // can't use it.  For example, 64-bit reg classes on 32-bit targets.
5335     if (!isLegalRC(*RI, *RC))
5336       continue;
5337 
5338     for (const MCPhysReg &PR : *RC) {
5339       if (RegName.equals_insensitive(RI->getRegAsmName(PR))) {
5340         std::pair<unsigned, const TargetRegisterClass *> S =
5341             std::make_pair(PR, RC);
5342 
5343         // If this register class has the requested value type, return it,
5344         // otherwise keep searching and return the first class found
5345         // if no other is found which explicitly has the requested type.
5346         if (RI->isTypeLegalForClass(*RC, VT))
5347           return S;
5348         if (!R.second)
5349           R = S;
5350       }
5351     }
5352   }
5353 
5354   return R;
5355 }
5356 
5357 //===----------------------------------------------------------------------===//
5358 // Constraint Selection.
5359 
5360 /// Return true of this is an input operand that is a matching constraint like
5361 /// "4".
isMatchingInputConstraint() const5362 bool TargetLowering::AsmOperandInfo::isMatchingInputConstraint() const {
5363   assert(!ConstraintCode.empty() && "No known constraint!");
5364   return isdigit(static_cast<unsigned char>(ConstraintCode[0]));
5365 }
5366 
5367 /// If this is an input matching constraint, this method returns the output
5368 /// operand it matches.
getMatchedOperand() const5369 unsigned TargetLowering::AsmOperandInfo::getMatchedOperand() const {
5370   assert(!ConstraintCode.empty() && "No known constraint!");
5371   return atoi(ConstraintCode.c_str());
5372 }
5373 
5374 /// Split up the constraint string from the inline assembly value into the
5375 /// specific constraints and their prefixes, and also tie in the associated
5376 /// operand values.
5377 /// If this returns an empty vector, and if the constraint string itself
5378 /// isn't empty, there was an error parsing.
TargetLowering::AsmOperandInfoVector
TargetLowering::ParseConstraints(const DataLayout &DL,
                                 const TargetRegisterInfo *TRI,
                                 const CallBase &Call) const {
  /// Information about all of the constraints.
  AsmOperandInfoVector ConstraintOperands;
  const InlineAsm *IA = cast<InlineAsm>(Call.getCalledOperand());
  unsigned maCount = 0; // Largest number of multiple alternative constraints.

  // Do a prepass over the constraints, canonicalizing them, and building up the
  // ConstraintOperands list.
  unsigned ArgNo = 0; // ArgNo - The argument of the CallInst.
  unsigned ResNo = 0; // ResNo - The result number of the next output.
  unsigned LabelNo = 0; // LabelNo - CallBr indirect dest number.

  for (InlineAsm::ConstraintInfo &CI : IA->ParseConstraints()) {
    ConstraintOperands.emplace_back(std::move(CI));
    AsmOperandInfo &OpInfo = ConstraintOperands.back();

    // Update multiple alternative constraint count.
    if (OpInfo.multipleAlternatives.size() > maCount)
      maCount = OpInfo.multipleAlternatives.size();

    OpInfo.ConstraintVT = MVT::Other;

    // Compute the value type for each operand.
    switch (OpInfo.Type) {
    case InlineAsm::isOutput:
      // Indirect outputs just consume an argument.
      if (OpInfo.isIndirect) {
        OpInfo.CallOperandVal = Call.getArgOperand(ArgNo);
        break;
      }

      // The return value of the call is this value.  As such, there is no
      // corresponding argument.
      assert(!Call.getType()->isVoidTy() && "Bad inline asm!");
      if (StructType *STy = dyn_cast<StructType>(Call.getType())) {
        // Multiple outputs come back as a struct; ResNo picks this output's
        // element.
        OpInfo.ConstraintVT =
            getSimpleValueType(DL, STy->getElementType(ResNo));
      } else {
        assert(ResNo == 0 && "Asm only has one result!");
        OpInfo.ConstraintVT =
            getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
      }
      ++ResNo;
      break;
    case InlineAsm::isInput:
      OpInfo.CallOperandVal = Call.getArgOperand(ArgNo);
      break;
    case InlineAsm::isLabel:
      // Labels are only valid on callbr; they index its indirect dests and
      // consume no call argument (hence `continue`, skipping ArgNo++).
      OpInfo.CallOperandVal = cast<CallBrInst>(&Call)->getIndirectDest(LabelNo);
      ++LabelNo;
      continue;
    case InlineAsm::isClobber:
      // Nothing to do.
      break;
    }

    if (OpInfo.CallOperandVal) {
      llvm::Type *OpTy = OpInfo.CallOperandVal->getType();
      if (OpInfo.isIndirect) {
        OpTy = Call.getParamElementType(ArgNo);
        assert(OpTy && "Indirect operand must have elementtype attribute");
      }

      // Look for vector wrapped in a struct. e.g. { <16 x i8> }.
      if (StructType *STy = dyn_cast<StructType>(OpTy))
        if (STy->getNumElements() == 1)
          OpTy = STy->getElementType(0);

      // If OpTy is not a single value, it may be a struct/union that we
      // can tile with integers.
      if (!OpTy->isSingleValueType() && OpTy->isSized()) {
        unsigned BitSize = DL.getTypeSizeInBits(OpTy);
        switch (BitSize) {
        default: break;
        case 1:
        case 8:
        case 16:
        case 32:
        case 64:
        case 128:
          OpTy = IntegerType::get(OpTy->getContext(), BitSize);
          break;
        }
      }

      EVT VT = getAsmOperandValueType(DL, OpTy, true);
      OpInfo.ConstraintVT = VT.isSimple() ? VT.getSimpleVT() : MVT::Other;
      ArgNo++;
    }
  }

  // If we have multiple alternative constraints, select the best alternative.
  if (!ConstraintOperands.empty()) {
    if (maCount) {
      unsigned bestMAIndex = 0;
      int bestWeight = -1;
      // weight:  -1 = invalid match, and 0 = so-so match to 5 = good match.
      int weight = -1;
      unsigned maIndex;
      // Compute the sums of the weights for each alternative, keeping track
      // of the best (highest weight) one so far.
      for (maIndex = 0; maIndex < maCount; ++maIndex) {
        int weightSum = 0;
        for (unsigned cIndex = 0, eIndex = ConstraintOperands.size();
             cIndex != eIndex; ++cIndex) {
          AsmOperandInfo &OpInfo = ConstraintOperands[cIndex];
          if (OpInfo.Type == InlineAsm::isClobber)
            continue;

          // If this is an output operand with a matching input operand,
          // look up the matching input. If their types mismatch, e.g. one
          // is an integer, the other is floating point, or their sizes are
          // different, flag it as an maCantMatch.
          if (OpInfo.hasMatchingInput()) {
            AsmOperandInfo &Input = ConstraintOperands[OpInfo.MatchingInput];
            if (OpInfo.ConstraintVT != Input.ConstraintVT) {
              if ((OpInfo.ConstraintVT.isInteger() !=
                   Input.ConstraintVT.isInteger()) ||
                  (OpInfo.ConstraintVT.getSizeInBits() !=
                   Input.ConstraintVT.getSizeInBits())) {
                weightSum = -1; // Can't match.
                break;
              }
            }
          }
          weight = getMultipleConstraintMatchWeight(OpInfo, maIndex);
          if (weight == -1) {
            // Any unmatchable operand invalidates this whole alternative.
            weightSum = -1;
            break;
          }
          weightSum += weight;
        }
        // Update best.
        if (weightSum > bestWeight) {
          bestWeight = weightSum;
          bestMAIndex = maIndex;
        }
      }

      // Now select chosen alternative in each constraint.
      for (AsmOperandInfo &cInfo : ConstraintOperands)
        if (cInfo.Type != InlineAsm::isClobber)
          cInfo.selectAlternative(bestMAIndex);
    }
  }

  // Check and hook up tied operands, choose constraint code to use.
  for (unsigned cIndex = 0, eIndex = ConstraintOperands.size();
       cIndex != eIndex; ++cIndex) {
    AsmOperandInfo &OpInfo = ConstraintOperands[cIndex];

    // If this is an output operand with a matching input operand, look up the
    // matching input. If their types mismatch, e.g. one is an integer, the
    // other is floating point, or their sizes are different, flag it as an
    // error.
    if (OpInfo.hasMatchingInput()) {
      AsmOperandInfo &Input = ConstraintOperands[OpInfo.MatchingInput];

      if (OpInfo.ConstraintVT != Input.ConstraintVT) {
        // Types differ; they are still compatible if both map to the same
        // register class and agree on integer-ness.
        std::pair<unsigned, const TargetRegisterClass *> MatchRC =
            getRegForInlineAsmConstraint(TRI, OpInfo.ConstraintCode,
                                         OpInfo.ConstraintVT);
        std::pair<unsigned, const TargetRegisterClass *> InputRC =
            getRegForInlineAsmConstraint(TRI, Input.ConstraintCode,
                                         Input.ConstraintVT);
        if ((OpInfo.ConstraintVT.isInteger() !=
             Input.ConstraintVT.isInteger()) ||
            (MatchRC.second != InputRC.second)) {
          report_fatal_error("Unsupported asm: input constraint"
                             " with a matching output constraint of"
                             " incompatible type!");
        }
      }
    }
  }

  return ConstraintOperands;
}
5560 
5561 /// Return an integer indicating how general CT is.
getConstraintGenerality(TargetLowering::ConstraintType CT)5562 static unsigned getConstraintGenerality(TargetLowering::ConstraintType CT) {
5563   switch (CT) {
5564   case TargetLowering::C_Immediate:
5565   case TargetLowering::C_Other:
5566   case TargetLowering::C_Unknown:
5567     return 0;
5568   case TargetLowering::C_Register:
5569     return 1;
5570   case TargetLowering::C_RegisterClass:
5571     return 2;
5572   case TargetLowering::C_Memory:
5573   case TargetLowering::C_Address:
5574     return 3;
5575   }
5576   llvm_unreachable("Invalid constraint type");
5577 }
5578 
5579 /// Examine constraint type and operand type and determine a weight value.
5580 /// This object must already have been set up with the operand type
5581 /// and the current alternative constraint selected.
5582 TargetLowering::ConstraintWeight
getMultipleConstraintMatchWeight(AsmOperandInfo & info,int maIndex) const5583   TargetLowering::getMultipleConstraintMatchWeight(
5584     AsmOperandInfo &info, int maIndex) const {
5585   InlineAsm::ConstraintCodeVector *rCodes;
5586   if (maIndex >= (int)info.multipleAlternatives.size())
5587     rCodes = &info.Codes;
5588   else
5589     rCodes = &info.multipleAlternatives[maIndex].Codes;
5590   ConstraintWeight BestWeight = CW_Invalid;
5591 
5592   // Loop over the options, keeping track of the most general one.
5593   for (const std::string &rCode : *rCodes) {
5594     ConstraintWeight weight =
5595         getSingleConstraintMatchWeight(info, rCode.c_str());
5596     if (weight > BestWeight)
5597       BestWeight = weight;
5598   }
5599 
5600   return BestWeight;
5601 }
5602 
5603 /// Examine constraint type and operand type and determine a weight value.
5604 /// This object must already have been set up with the operand type
5605 /// and the current alternative constraint selected.
5606 TargetLowering::ConstraintWeight
getSingleConstraintMatchWeight(AsmOperandInfo & info,const char * constraint) const5607   TargetLowering::getSingleConstraintMatchWeight(
5608     AsmOperandInfo &info, const char *constraint) const {
5609   ConstraintWeight weight = CW_Invalid;
5610   Value *CallOperandVal = info.CallOperandVal;
5611     // If we don't have a value, we can't do a match,
5612     // but allow it at the lowest weight.
5613   if (!CallOperandVal)
5614     return CW_Default;
5615   // Look at the constraint type.
5616   switch (*constraint) {
5617     case 'i': // immediate integer.
5618     case 'n': // immediate integer with a known value.
5619       if (isa<ConstantInt>(CallOperandVal))
5620         weight = CW_Constant;
5621       break;
5622     case 's': // non-explicit intregal immediate.
5623       if (isa<GlobalValue>(CallOperandVal))
5624         weight = CW_Constant;
5625       break;
5626     case 'E': // immediate float if host format.
5627     case 'F': // immediate float.
5628       if (isa<ConstantFP>(CallOperandVal))
5629         weight = CW_Constant;
5630       break;
5631     case '<': // memory operand with autodecrement.
5632     case '>': // memory operand with autoincrement.
5633     case 'm': // memory operand.
5634     case 'o': // offsettable memory operand
5635     case 'V': // non-offsettable memory operand
5636       weight = CW_Memory;
5637       break;
5638     case 'r': // general register.
5639     case 'g': // general register, memory operand or immediate integer.
5640               // note: Clang converts "g" to "imr".
5641       if (CallOperandVal->getType()->isIntegerTy())
5642         weight = CW_Register;
5643       break;
5644     case 'X': // any operand.
5645   default:
5646     weight = CW_Default;
5647     break;
5648   }
5649   return weight;
5650 }
5651 
5652 /// If there are multiple different constraints that we could pick for this
5653 /// operand (e.g. "imr") try to pick the 'best' one.
5654 /// This is somewhat tricky: constraints fall into four classes:
5655 ///    Other         -> immediates and magic values
5656 ///    Register      -> one specific register
5657 ///    RegisterClass -> a group of regs
5658 ///    Memory        -> memory
5659 /// Ideally, we would pick the most specific constraint possible: if we have
5660 /// something that fits into a register, we would pick it.  The problem here
5661 /// is that if we have something that could either be in a register or in
5662 /// memory that use of the register could cause selection of *other*
5663 /// operands to fail: they might only succeed if we pick memory.  Because of
5664 /// this the heuristic we use is:
5665 ///
5666 ///  1) If there is an 'other' constraint, and if the operand is valid for
5667 ///     that constraint, use it.  This makes us take advantage of 'i'
5668 ///     constraints when available.
5669 ///  2) Otherwise, pick the most general constraint present.  This prefers
5670 ///     'm' over 'r', for example.
5671 ///
ChooseConstraint(TargetLowering::AsmOperandInfo & OpInfo,const TargetLowering & TLI,SDValue Op,SelectionDAG * DAG)5672 static void ChooseConstraint(TargetLowering::AsmOperandInfo &OpInfo,
5673                              const TargetLowering &TLI,
5674                              SDValue Op, SelectionDAG *DAG) {
5675   assert(OpInfo.Codes.size() > 1 && "Doesn't have multiple constraint options");
5676   unsigned BestIdx = 0;
5677   TargetLowering::ConstraintType BestType = TargetLowering::C_Unknown;
5678   int BestGenerality = -1;
5679 
5680   // Loop over the options, keeping track of the most general one.
5681   for (unsigned i = 0, e = OpInfo.Codes.size(); i != e; ++i) {
5682     TargetLowering::ConstraintType CType =
5683       TLI.getConstraintType(OpInfo.Codes[i]);
5684 
5685     // Indirect 'other' or 'immediate' constraints are not allowed.
5686     if (OpInfo.isIndirect && !(CType == TargetLowering::C_Memory ||
5687                                CType == TargetLowering::C_Register ||
5688                                CType == TargetLowering::C_RegisterClass))
5689       continue;
5690 
5691     // If this is an 'other' or 'immediate' constraint, see if the operand is
5692     // valid for it. For example, on X86 we might have an 'rI' constraint. If
5693     // the operand is an integer in the range [0..31] we want to use I (saving a
5694     // load of a register), otherwise we must use 'r'.
5695     if ((CType == TargetLowering::C_Other ||
5696          CType == TargetLowering::C_Immediate) && Op.getNode()) {
5697       assert(OpInfo.Codes[i].size() == 1 &&
5698              "Unhandled multi-letter 'other' constraint");
5699       std::vector<SDValue> ResultOps;
5700       TLI.LowerAsmOperandForConstraint(Op, OpInfo.Codes[i],
5701                                        ResultOps, *DAG);
5702       if (!ResultOps.empty()) {
5703         BestType = CType;
5704         BestIdx = i;
5705         break;
5706       }
5707     }
5708 
5709     // Things with matching constraints can only be registers, per gcc
5710     // documentation.  This mainly affects "g" constraints.
5711     if (CType == TargetLowering::C_Memory && OpInfo.hasMatchingInput())
5712       continue;
5713 
5714     // This constraint letter is more general than the previous one, use it.
5715     int Generality = getConstraintGenerality(CType);
5716     if (Generality > BestGenerality) {
5717       BestType = CType;
5718       BestIdx = i;
5719       BestGenerality = Generality;
5720     }
5721   }
5722 
5723   OpInfo.ConstraintCode = OpInfo.Codes[BestIdx];
5724   OpInfo.ConstraintType = BestType;
5725 }
5726 
5727 /// Determines the constraint code and constraint type to use for the specific
5728 /// AsmOperandInfo, setting OpInfo.ConstraintCode and OpInfo.ConstraintType.
ComputeConstraintToUse(AsmOperandInfo & OpInfo,SDValue Op,SelectionDAG * DAG) const5729 void TargetLowering::ComputeConstraintToUse(AsmOperandInfo &OpInfo,
5730                                             SDValue Op,
5731                                             SelectionDAG *DAG) const {
5732   assert(!OpInfo.Codes.empty() && "Must have at least one constraint");
5733 
5734   // Single-letter constraints ('r') are very common.
5735   if (OpInfo.Codes.size() == 1) {
5736     OpInfo.ConstraintCode = OpInfo.Codes[0];
5737     OpInfo.ConstraintType = getConstraintType(OpInfo.ConstraintCode);
5738   } else {
5739     ChooseConstraint(OpInfo, *this, Op, DAG);
5740   }
5741 
5742   // 'X' matches anything.
5743   if (OpInfo.ConstraintCode == "X" && OpInfo.CallOperandVal) {
5744     // Constants are handled elsewhere.  For Functions, the type here is the
5745     // type of the result, which is not what we want to look at; leave them
5746     // alone.
5747     Value *v = OpInfo.CallOperandVal;
5748     if (isa<ConstantInt>(v) || isa<Function>(v)) {
5749       return;
5750     }
5751 
5752     if (isa<BasicBlock>(v) || isa<BlockAddress>(v)) {
5753       OpInfo.ConstraintCode = "i";
5754       return;
5755     }
5756 
5757     // Otherwise, try to resolve it to something we know about by looking at
5758     // the actual operand type.
5759     if (const char *Repl = LowerXConstraint(OpInfo.ConstraintVT)) {
5760       OpInfo.ConstraintCode = Repl;
5761       OpInfo.ConstraintType = getConstraintType(OpInfo.ConstraintCode);
5762     }
5763   }
5764 }
5765 
5766 /// Given an exact SDIV by a constant, create a multiplication
5767 /// with the multiplicative inverse of the constant.
static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
                              const SDLoc &dl, SelectionDAG &DAG,
                              SmallVectorImpl<SDNode *> &Created) {
  SDValue Op0 = N->getOperand(0); // Dividend.
  SDValue Op1 = N->getOperand(1); // Constant divisor (scalar, splat or vector).
  EVT VT = N->getValueType(0);
  EVT SVT = VT.getScalarType();
  EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
  EVT ShSVT = ShVT.getScalarType();

  // Set when at least one divisor element is even, so a pre-shift is needed.
  bool UseSRA = false;
  SmallVector<SDValue, 16> Shifts, Factors;

  // For each divisor element D, write D = D0 * 2^Shift with D0 odd, and
  // record Shift plus the multiplicative inverse of D0 modulo 2^BitWidth.
  // Because the division is exact, x/D == (x >> Shift) * inverse(D0).
  auto BuildSDIVPattern = [&](ConstantSDNode *C) {
    if (C->isZero())
      return false; // Division by zero: reject the whole transform.
    APInt Divisor = C->getAPIntValue();
    unsigned Shift = Divisor.countTrailingZeros();
    if (Shift) {
      Divisor.ashrInPlace(Shift);
      UseSRA = true;
    }
    // Calculate the multiplicative inverse, using Newton's method.
    // Each iteration doubles the number of correct low bits of Factor.
    APInt t;
    APInt Factor = Divisor;
    while ((t = Divisor * Factor) != 1)
      Factor *= APInt(Divisor.getBitWidth(), 2) - t;
    Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
    Factors.push_back(DAG.getConstant(Factor, dl, SVT));
    return true;
  };

  // Collect all magic values from the build vector.
  if (!ISD::matchUnaryPredicate(Op1, BuildSDIVPattern))
    return SDValue();

  // Re-assemble the per-element shifts/factors in the same form as Op1.
  SDValue Shift, Factor;
  if (Op1.getOpcode() == ISD::BUILD_VECTOR) {
    Shift = DAG.getBuildVector(ShVT, dl, Shifts);
    Factor = DAG.getBuildVector(VT, dl, Factors);
  } else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) {
    assert(Shifts.size() == 1 && Factors.size() == 1 &&
           "Expected matchUnaryPredicate to return one element for scalable "
           "vectors");
    Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
    Factor = DAG.getSplatVector(VT, dl, Factors[0]);
  } else {
    assert(isa<ConstantSDNode>(Op1) && "Expected a constant");
    Shift = Shifts[0];
    Factor = Factors[0];
  }

  SDValue Res = Op0;

  // Shift the value upfront if it is even, so the LSB is one.
  if (UseSRA) {
    // TODO: For UDIV use SRL instead of SRA.
    // The shift is exact: the bits shifted out are known zero for an exact
    // division by an even constant.
    SDNodeFlags Flags;
    Flags.setExact(true);
    Res = DAG.getNode(ISD::SRA, dl, VT, Res, Shift, Flags);
    Created.push_back(Res.getNode());
  }

  // Finally multiply by the inverse of the odd part.
  return DAG.getNode(ISD::MUL, dl, VT, Res, Factor);
}
5833 
BuildSDIVPow2(SDNode * N,const APInt & Divisor,SelectionDAG & DAG,SmallVectorImpl<SDNode * > & Created) const5834 SDValue TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
5835                               SelectionDAG &DAG,
5836                               SmallVectorImpl<SDNode *> &Created) const {
5837   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5838   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5839   if (TLI.isIntDivCheap(N->getValueType(0), Attr))
5840     return SDValue(N, 0); // Lower SDIV as SDIV
5841   return SDValue();
5842 }
5843 
5844 SDValue
BuildSREMPow2(SDNode * N,const APInt & Divisor,SelectionDAG & DAG,SmallVectorImpl<SDNode * > & Created) const5845 TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
5846                               SelectionDAG &DAG,
5847                               SmallVectorImpl<SDNode *> &Created) const {
5848   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5849   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5850   if (TLI.isIntDivCheap(N->getValueType(0), Attr))
5851     return SDValue(N, 0); // Lower SREM as SREM
5852   return SDValue();
5853 }
5854 
5855 /// Given an ISD::SDIV node expressing a divide by constant,
5856 /// return a DAG expression to select that will generate the same value by
5857 /// multiplying by a magic number.
5858 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
                                  bool IsAfterLegalization,
                                  SmallVectorImpl<SDNode *> &Created) const {
  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  EVT SVT = VT.getScalarType();
  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
  EVT ShSVT = ShVT.getScalarType();
  unsigned EltBits = VT.getScalarSizeInBits();
  // Wider type used for the multiply when VT itself is not legal.
  EVT MulVT;

  // Check to see if we can do this.
  // FIXME: We should be more aggressive here.
  if (!isTypeLegal(VT)) {
    // Limit this to simple scalars for now.
    if (VT.isVector() || !VT.isSimple())
      return SDValue();

    // If this type will be promoted to a large enough type with a legal
    // multiply operation, we can go ahead and do this transform.
    if (getTypeAction(VT.getSimpleVT()) != TypePromoteInteger)
      return SDValue();

    MulVT = getTypeToTransformTo(*DAG.getContext(), VT);
    if (MulVT.getSizeInBits() < (2 * EltBits) ||
        !isOperationLegal(ISD::MUL, MulVT))
      return SDValue();
  }

  // If the sdiv has an 'exact' bit we can use a simpler lowering.
  if (N->getFlags().hasExact())
    return BuildExactSDIV(*this, N, dl, DAG, Created);

  SmallVector<SDValue, 16> MagicFactors, Factors, Shifts, ShiftMasks;

  // Compute, per divisor element: the magic multiplier, whether the
  // numerator must be added/subtracted afterwards (Factor of +1/-1/0),
  // the post-multiply arithmetic shift, and a mask (0 or -1) that enables
  // or disables the final sign-bit correction.
  auto BuildSDIVPattern = [&](ConstantSDNode *C) {
    if (C->isZero())
      return false; // Division by zero: reject the whole transform.

    const APInt &Divisor = C->getAPIntValue();
    SignedDivisionByConstantInfo magics = SignedDivisionByConstantInfo::get(Divisor);
    int NumeratorFactor = 0;
    int ShiftMask = -1;

    if (Divisor.isOne() || Divisor.isAllOnes()) {
      // If d is +1/-1, we just multiply the numerator by +1/-1.
      NumeratorFactor = Divisor.getSExtValue();
      magics.Magic = 0;
      magics.ShiftAmount = 0;
      ShiftMask = 0; // No sign-bit correction needed for +1/-1.
    } else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
      // If d > 0 and m < 0, add the numerator.
      NumeratorFactor = 1;
    } else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
      // If d < 0 and m > 0, subtract the numerator.
      NumeratorFactor = -1;
    }

    MagicFactors.push_back(DAG.getConstant(magics.Magic, dl, SVT));
    Factors.push_back(DAG.getConstant(NumeratorFactor, dl, SVT));
    Shifts.push_back(DAG.getConstant(magics.ShiftAmount, dl, ShSVT));
    ShiftMasks.push_back(DAG.getConstant(ShiftMask, dl, SVT));
    return true;
  };

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  // Collect the shifts / magic values from each element.
  if (!ISD::matchUnaryPredicate(N1, BuildSDIVPattern))
    return SDValue();

  // Re-assemble the per-element values in the same form as the divisor:
  // build vector, splat (scalable) or scalar.
  SDValue MagicFactor, Factor, Shift, ShiftMask;
  if (N1.getOpcode() == ISD::BUILD_VECTOR) {
    MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors);
    Factor = DAG.getBuildVector(VT, dl, Factors);
    Shift = DAG.getBuildVector(ShVT, dl, Shifts);
    ShiftMask = DAG.getBuildVector(VT, dl, ShiftMasks);
  } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
    assert(MagicFactors.size() == 1 && Factors.size() == 1 &&
           Shifts.size() == 1 && ShiftMasks.size() == 1 &&
           "Expected matchUnaryPredicate to return one element for scalable "
           "vectors");
    MagicFactor = DAG.getSplatVector(VT, dl, MagicFactors[0]);
    Factor = DAG.getSplatVector(VT, dl, Factors[0]);
    Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
    ShiftMask = DAG.getSplatVector(VT, dl, ShiftMasks[0]);
  } else {
    assert(isa<ConstantSDNode>(N1) && "Expected a constant");
    MagicFactor = MagicFactors[0];
    Factor = Factors[0];
    Shift = Shifts[0];
    ShiftMask = ShiftMasks[0];
  }

  // Multiply the numerator (operand 0) by the magic value.
  // FIXME: We should support doing a MUL in a wider type.
  auto GetMULHS = [&](SDValue X, SDValue Y) {
    // If the type isn't legal, use a wider mul of the type calculated
    // earlier.
    if (!isTypeLegal(VT)) {
      X = DAG.getNode(ISD::SIGN_EXTEND, dl, MulVT, X);
      Y = DAG.getNode(ISD::SIGN_EXTEND, dl, MulVT, Y);
      Y = DAG.getNode(ISD::MUL, dl, MulVT, X, Y);
      // Keep only the high EltBits of the double-width product.
      Y = DAG.getNode(ISD::SRL, dl, MulVT, Y,
                      DAG.getShiftAmountConstant(EltBits, MulVT, dl));
      return DAG.getNode(ISD::TRUNCATE, dl, VT, Y);
    }

    if (isOperationLegalOrCustom(ISD::MULHS, VT, IsAfterLegalization))
      return DAG.getNode(ISD::MULHS, dl, VT, X, Y);
    if (isOperationLegalOrCustom(ISD::SMUL_LOHI, VT, IsAfterLegalization)) {
      // Result 1 of SMUL_LOHI is the high half of the product.
      SDValue LoHi =
          DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), X, Y);
      return SDValue(LoHi.getNode(), 1);
    }
    return SDValue(); // No signed high-multiply available.
  };

  SDValue Q = GetMULHS(N0, MagicFactor);
  if (!Q)
    return SDValue();

  Created.push_back(Q.getNode());

  // (Optionally) Add/subtract the numerator using Factor.
  // Factor is +1, -1 or 0 per element, so this is N0, -N0 or 0.
  Factor = DAG.getNode(ISD::MUL, dl, VT, N0, Factor);
  Created.push_back(Factor.getNode());
  Q = DAG.getNode(ISD::ADD, dl, VT, Q, Factor);
  Created.push_back(Q.getNode());

  // Shift right algebraic by shift value.
  Q = DAG.getNode(ISD::SRA, dl, VT, Q, Shift);
  Created.push_back(Q.getNode());

  // Extract the sign bit, mask it and add it to the quotient.
  // ShiftMask is 0 for +1/-1 divisors (no correction) and -1 otherwise.
  SDValue SignShift = DAG.getConstant(EltBits - 1, dl, ShVT);
  SDValue T = DAG.getNode(ISD::SRL, dl, VT, Q, SignShift);
  Created.push_back(T.getNode());
  T = DAG.getNode(ISD::AND, dl, VT, T, ShiftMask);
  Created.push_back(T.getNode());
  return DAG.getNode(ISD::ADD, dl, VT, Q, T);
}
6002 
6003 /// Given an ISD::UDIV node expressing a divide by constant,
6004 /// return a DAG expression to select that will generate the same value by
6005 /// multiplying by a magic number.
6006 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
                                  bool IsAfterLegalization,
                                  SmallVectorImpl<SDNode *> &Created) const {
  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  EVT SVT = VT.getScalarType();
  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
  EVT ShSVT = ShVT.getScalarType();
  unsigned EltBits = VT.getScalarSizeInBits();
  // Wider type used for the multiply when VT itself is not legal.
  EVT MulVT;

  // Check to see if we can do this.
  // FIXME: We should be more aggressive here.
  if (!isTypeLegal(VT)) {
    // Limit this to simple scalars for now.
    if (VT.isVector() || !VT.isSimple())
      return SDValue();

    // If this type will be promoted to a large enough type with a legal
    // multiply operation, we can go ahead and do this transform.
    if (getTypeAction(VT.getSimpleVT()) != TypePromoteInteger)
      return SDValue();

    MulVT = getTypeToTransformTo(*DAG.getContext(), VT);
    if (MulVT.getSizeInBits() < (2 * EltBits) ||
        !isOperationLegal(ISD::MUL, MulVT))
      return SDValue();
  }

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  // Try to use leading zeros of the dividend to reduce the multiplier and
  // avoid expensive fixups.
  // TODO: Support vectors.
  unsigned LeadingZeros = 0;
  if (!VT.isVector() && isa<ConstantSDNode>(N1)) {
    assert(!isOneConstant(N1) && "Unexpected divisor");
    LeadingZeros = DAG.computeKnownBits(N0).countMinLeadingZeros();
    // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros in
    // the dividend exceeds the leading zeros for the divisor.
    LeadingZeros =
        std::min(LeadingZeros,
                 cast<ConstantSDNode>(N1)->getAPIntValue().countLeadingZeros());
  }

  // Whether any element needs the NPQ fixup, a pre-shift or a post-shift.
  bool UseNPQ = false, UsePreShift = false, UsePostShift = false;
  SmallVector<SDValue, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;

  // Compute the magic multiplier and shifts for each divisor element.
  auto BuildUDIVPattern = [&](ConstantSDNode *C) {
    if (C->isZero())
      return false; // Division by zero: reject the whole transform.
    const APInt& Divisor = C->getAPIntValue();

    SDValue PreShift, MagicFactor, NPQFactor, PostShift;

    // Magic algorithm doesn't work for division by 1. We need to emit a select
    // at the end.
    if (Divisor.isOne()) {
      PreShift = PostShift = DAG.getUNDEF(ShSVT);
      MagicFactor = NPQFactor = DAG.getUNDEF(SVT);
    } else {
      UnsignedDivisionByConstantInfo magics =
          UnsignedDivisionByConstantInfo::get(Divisor, LeadingZeros);

      MagicFactor = DAG.getConstant(magics.Magic, dl, SVT);

      assert(magics.PreShift < Divisor.getBitWidth() &&
             "We shouldn't generate an undefined shift!");
      assert(magics.PostShift < Divisor.getBitWidth() &&
             "We shouldn't generate an undefined shift!");
      assert((!magics.IsAdd || magics.PreShift == 0) &&
             "Unexpected pre-shift");
      PreShift = DAG.getConstant(magics.PreShift, dl, ShSVT);
      PostShift = DAG.getConstant(magics.PostShift, dl, ShSVT);
      // For the vector NPQ path the factor is 2^(EltBits-1), so a MULHU by
      // it acts as a logical shift-right by one; zero disables the fixup.
      NPQFactor = DAG.getConstant(
          magics.IsAdd ? APInt::getOneBitSet(EltBits, EltBits - 1)
                       : APInt::getZero(EltBits),
          dl, SVT);
      UseNPQ |= magics.IsAdd;
      UsePreShift |= magics.PreShift != 0;
      UsePostShift |= magics.PostShift != 0;
    }

    PreShifts.push_back(PreShift);
    MagicFactors.push_back(MagicFactor);
    NPQFactors.push_back(NPQFactor);
    PostShifts.push_back(PostShift);
    return true;
  };

  // Collect the shifts/magic values from each element.
  if (!ISD::matchUnaryPredicate(N1, BuildUDIVPattern))
    return SDValue();

  // Re-assemble the per-element values in the same form as the divisor:
  // build vector, splat (scalable) or scalar. NPQFactor is only needed for
  // vectors; the scalar NPQ path uses an explicit SRL-by-1 instead.
  SDValue PreShift, PostShift, MagicFactor, NPQFactor;
  if (N1.getOpcode() == ISD::BUILD_VECTOR) {
    PreShift = DAG.getBuildVector(ShVT, dl, PreShifts);
    MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors);
    NPQFactor = DAG.getBuildVector(VT, dl, NPQFactors);
    PostShift = DAG.getBuildVector(ShVT, dl, PostShifts);
  } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
    assert(PreShifts.size() == 1 && MagicFactors.size() == 1 &&
           NPQFactors.size() == 1 && PostShifts.size() == 1 &&
           "Expected matchUnaryPredicate to return one for scalable vectors");
    PreShift = DAG.getSplatVector(ShVT, dl, PreShifts[0]);
    MagicFactor = DAG.getSplatVector(VT, dl, MagicFactors[0]);
    NPQFactor = DAG.getSplatVector(VT, dl, NPQFactors[0]);
    PostShift = DAG.getSplatVector(ShVT, dl, PostShifts[0]);
  } else {
    assert(isa<ConstantSDNode>(N1) && "Expected a constant");
    PreShift = PreShifts[0];
    MagicFactor = MagicFactors[0];
    PostShift = PostShifts[0];
  }

  SDValue Q = N0;
  if (UsePreShift) {
    Q = DAG.getNode(ISD::SRL, dl, VT, Q, PreShift);
    Created.push_back(Q.getNode());
  }

  // FIXME: We should support doing a MUL in a wider type.
  auto GetMULHU = [&](SDValue X, SDValue Y) {
    // If the type isn't legal, use a wider mul of the type calculated
    // earlier.
    if (!isTypeLegal(VT)) {
      X = DAG.getNode(ISD::ZERO_EXTEND, dl, MulVT, X);
      Y = DAG.getNode(ISD::ZERO_EXTEND, dl, MulVT, Y);
      Y = DAG.getNode(ISD::MUL, dl, MulVT, X, Y);
      // Keep only the high EltBits of the double-width product.
      Y = DAG.getNode(ISD::SRL, dl, MulVT, Y,
                      DAG.getShiftAmountConstant(EltBits, MulVT, dl));
      return DAG.getNode(ISD::TRUNCATE, dl, VT, Y);
    }

    if (isOperationLegalOrCustom(ISD::MULHU, VT, IsAfterLegalization))
      return DAG.getNode(ISD::MULHU, dl, VT, X, Y);
    if (isOperationLegalOrCustom(ISD::UMUL_LOHI, VT, IsAfterLegalization)) {
      // Result 1 of UMUL_LOHI is the high half of the product.
      SDValue LoHi =
          DAG.getNode(ISD::UMUL_LOHI, dl, DAG.getVTList(VT, VT), X, Y);
      return SDValue(LoHi.getNode(), 1);
    }
    return SDValue(); // No mulhu or equivalent
  };

  // Multiply the numerator (operand 0) by the magic value.
  Q = GetMULHU(Q, MagicFactor);
  if (!Q)
    return SDValue();

  Created.push_back(Q.getNode());

  if (UseNPQ) {
    // NPQ fixup: Q = ((N0 - Q) >> 1) + Q (Hacker's Delight "add" case).
    SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N0, Q);
    Created.push_back(NPQ.getNode());

    // For vectors we might have a mix of non-NPQ/NPQ paths, so use
    // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
    if (VT.isVector())
      NPQ = GetMULHU(NPQ, NPQFactor);
    else
      NPQ = DAG.getNode(ISD::SRL, dl, VT, NPQ, DAG.getConstant(1, dl, ShVT));

    Created.push_back(NPQ.getNode());

    Q = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q);
    Created.push_back(Q.getNode());
  }

  if (UsePostShift) {
    Q = DAG.getNode(ISD::SRL, dl, VT, Q, PostShift);
    Created.push_back(Q.getNode());
  }

  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

  // The magic sequence does not handle divisors of 1 (their lanes were left
  // undef above), so select the original numerator for those lanes.
  SDValue One = DAG.getConstant(1, dl, VT);
  SDValue IsOne = DAG.getSetCC(dl, SetCCVT, N1, One, ISD::SETEQ);
  return DAG.getSelect(dl, VT, IsOne, N0, Q);
}
6187 
6188 /// If all values in Values that *don't* match the predicate are same 'splat'
6189 /// value, then replace all values with that splat value.
6190 /// Else, if AlternativeReplacement was provided, then replace all values that
6191 /// do match predicate with AlternativeReplacement value.
6192 static void
turnVectorIntoSplatVector(MutableArrayRef<SDValue> Values,std::function<bool (SDValue)> Predicate,SDValue AlternativeReplacement=SDValue ())6193 turnVectorIntoSplatVector(MutableArrayRef<SDValue> Values,
6194                           std::function<bool(SDValue)> Predicate,
6195                           SDValue AlternativeReplacement = SDValue()) {
6196   SDValue Replacement;
6197   // Is there a value for which the Predicate does *NOT* match? What is it?
6198   auto SplatValue = llvm::find_if_not(Values, Predicate);
6199   if (SplatValue != Values.end()) {
6200     // Does Values consist only of SplatValue's and values matching Predicate?
6201     if (llvm::all_of(Values, [Predicate, SplatValue](SDValue Value) {
6202           return Value == *SplatValue || Predicate(Value);
6203         })) // Then we shall replace values matching predicate with SplatValue.
6204       Replacement = *SplatValue;
6205   }
6206   if (!Replacement) {
6207     // Oops, we did not find the "baseline" splat value.
6208     if (!AlternativeReplacement)
6209       return; // Nothing to do.
6210     // Let's replace with provided value then.
6211     Replacement = AlternativeReplacement;
6212   }
6213   std::replace_if(Values.begin(), Values.end(), Predicate, Replacement);
6214 }
6215 
6216 /// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE
6217 /// where the divisor is constant and the comparison target is zero,
6218 /// return a DAG expression that will generate the same comparison result
6219 /// using only multiplications, additions and shifts/rotations.
6220 /// Ref: "Hacker's Delight" 10-17.
buildUREMEqFold(EVT SETCCVT,SDValue REMNode,SDValue CompTargetNode,ISD::CondCode Cond,DAGCombinerInfo & DCI,const SDLoc & DL) const6221 SDValue TargetLowering::buildUREMEqFold(EVT SETCCVT, SDValue REMNode,
6222                                         SDValue CompTargetNode,
6223                                         ISD::CondCode Cond,
6224                                         DAGCombinerInfo &DCI,
6225                                         const SDLoc &DL) const {
6226   SmallVector<SDNode *, 5> Built;
6227   if (SDValue Folded = prepareUREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond,
6228                                          DCI, DL, Built)) {
6229     for (SDNode *N : Built)
6230       DCI.AddToWorklist(N);
6231     return Folded;
6232   }
6233 
6234   return SDValue();
6235 }
6236 
6237 SDValue
prepareUREMEqFold(EVT SETCCVT,SDValue REMNode,SDValue CompTargetNode,ISD::CondCode Cond,DAGCombinerInfo & DCI,const SDLoc & DL,SmallVectorImpl<SDNode * > & Created) const6238 TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
6239                                   SDValue CompTargetNode, ISD::CondCode Cond,
6240                                   DAGCombinerInfo &DCI, const SDLoc &DL,
6241                                   SmallVectorImpl<SDNode *> &Created) const {
6242   // fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q)
6243   // - D must be constant, with D = D0 * 2^K where D0 is odd
6244   // - P is the multiplicative inverse of D0 modulo 2^W
6245   // - Q = floor(((2^W) - 1) / D)
6246   // where W is the width of the common type of N and D.
6247   assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
6248          "Only applicable for (in)equality comparisons.");
6249 
6250   SelectionDAG &DAG = DCI.DAG;
6251 
6252   EVT VT = REMNode.getValueType();
6253   EVT SVT = VT.getScalarType();
6254   EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout(), !DCI.isBeforeLegalize());
6255   EVT ShSVT = ShVT.getScalarType();
6256 
6257   // If MUL is unavailable, we cannot proceed in any case.
6258   if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::MUL, VT))
6259     return SDValue();
6260 
6261   bool ComparingWithAllZeros = true;
6262   bool AllComparisonsWithNonZerosAreTautological = true;
6263   bool HadTautologicalLanes = false;
6264   bool AllLanesAreTautological = true;
6265   bool HadEvenDivisor = false;
6266   bool AllDivisorsArePowerOfTwo = true;
6267   bool HadTautologicalInvertedLanes = false;
6268   SmallVector<SDValue, 16> PAmts, KAmts, QAmts, IAmts;
6269 
6270   auto BuildUREMPattern = [&](ConstantSDNode *CDiv, ConstantSDNode *CCmp) {
6271     // Division by 0 is UB. Leave it to be constant-folded elsewhere.
6272     if (CDiv->isZero())
6273       return false;
6274 
6275     const APInt &D = CDiv->getAPIntValue();
6276     const APInt &Cmp = CCmp->getAPIntValue();
6277 
6278     ComparingWithAllZeros &= Cmp.isZero();
6279 
6280     // x u% C1` is *always* less than C1. So given `x u% C1 == C2`,
6281     // if C2 is not less than C1, the comparison is always false.
6282     // But we will only be able to produce the comparison that will give the
6283     // opposive tautological answer. So this lane would need to be fixed up.
6284     bool TautologicalInvertedLane = D.ule(Cmp);
6285     HadTautologicalInvertedLanes |= TautologicalInvertedLane;
6286 
6287     // If all lanes are tautological (either all divisors are ones, or divisor
6288     // is not greater than the constant we are comparing with),
6289     // we will prefer to avoid the fold.
6290     bool TautologicalLane = D.isOne() || TautologicalInvertedLane;
6291     HadTautologicalLanes |= TautologicalLane;
6292     AllLanesAreTautological &= TautologicalLane;
6293 
6294     // If we are comparing with non-zero, we need'll need  to subtract said
6295     // comparison value from the LHS. But there is no point in doing that if
6296     // every lane where we are comparing with non-zero is tautological..
6297     if (!Cmp.isZero())
6298       AllComparisonsWithNonZerosAreTautological &= TautologicalLane;
6299 
6300     // Decompose D into D0 * 2^K
6301     unsigned K = D.countTrailingZeros();
6302     assert((!D.isOne() || (K == 0)) && "For divisor '1' we won't rotate.");
6303     APInt D0 = D.lshr(K);
6304 
6305     // D is even if it has trailing zeros.
6306     HadEvenDivisor |= (K != 0);
6307     // D is a power-of-two if D0 is one.
6308     // If all divisors are power-of-two, we will prefer to avoid the fold.
6309     AllDivisorsArePowerOfTwo &= D0.isOne();
6310 
6311     // P = inv(D0, 2^W)
6312     // 2^W requires W + 1 bits, so we have to extend and then truncate.
6313     unsigned W = D.getBitWidth();
6314     APInt P = D0.zext(W + 1)
6315                   .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
6316                   .trunc(W);
6317     assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
6318     assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
6319 
6320     // Q = floor((2^W - 1) u/ D)
6321     // R = ((2^W - 1) u% D)
6322     APInt Q, R;
6323     APInt::udivrem(APInt::getAllOnes(W), D, Q, R);
6324 
6325     // If we are comparing with zero, then that comparison constant is okay,
6326     // else it may need to be one less than that.
6327     if (Cmp.ugt(R))
6328       Q -= 1;
6329 
6330     assert(APInt::getAllOnes(ShSVT.getSizeInBits()).ugt(K) &&
6331            "We are expecting that K is always less than all-ones for ShSVT");
6332 
6333     // If the lane is tautological the result can be constant-folded.
6334     if (TautologicalLane) {
6335       // Set P and K amount to a bogus values so we can try to splat them.
6336       P = 0;
6337       K = -1;
6338       // And ensure that comparison constant is tautological,
6339       // it will always compare true/false.
6340       Q = -1;
6341     }
6342 
6343     PAmts.push_back(DAG.getConstant(P, DL, SVT));
6344     KAmts.push_back(
6345         DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT));
6346     QAmts.push_back(DAG.getConstant(Q, DL, SVT));
6347     return true;
6348   };
6349 
6350   SDValue N = REMNode.getOperand(0);
6351   SDValue D = REMNode.getOperand(1);
6352 
6353   // Collect the values from each element.
6354   if (!ISD::matchBinaryPredicate(D, CompTargetNode, BuildUREMPattern))
6355     return SDValue();
6356 
6357   // If all lanes are tautological, the result can be constant-folded.
6358   if (AllLanesAreTautological)
6359     return SDValue();
6360 
6361   // If this is a urem by a powers-of-two, avoid the fold since it can be
6362   // best implemented as a bit test.
6363   if (AllDivisorsArePowerOfTwo)
6364     return SDValue();
6365 
6366   SDValue PVal, KVal, QVal;
6367   if (D.getOpcode() == ISD::BUILD_VECTOR) {
6368     if (HadTautologicalLanes) {
6369       // Try to turn PAmts into a splat, since we don't care about the values
6370       // that are currently '0'. If we can't, just keep '0'`s.
6371       turnVectorIntoSplatVector(PAmts, isNullConstant);
6372       // Try to turn KAmts into a splat, since we don't care about the values
6373       // that are currently '-1'. If we can't, change them to '0'`s.
6374       turnVectorIntoSplatVector(KAmts, isAllOnesConstant,
6375                                 DAG.getConstant(0, DL, ShSVT));
6376     }
6377 
6378     PVal = DAG.getBuildVector(VT, DL, PAmts);
6379     KVal = DAG.getBuildVector(ShVT, DL, KAmts);
6380     QVal = DAG.getBuildVector(VT, DL, QAmts);
6381   } else if (D.getOpcode() == ISD::SPLAT_VECTOR) {
6382     assert(PAmts.size() == 1 && KAmts.size() == 1 && QAmts.size() == 1 &&
6383            "Expected matchBinaryPredicate to return one element for "
6384            "SPLAT_VECTORs");
6385     PVal = DAG.getSplatVector(VT, DL, PAmts[0]);
6386     KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]);
6387     QVal = DAG.getSplatVector(VT, DL, QAmts[0]);
6388   } else {
6389     PVal = PAmts[0];
6390     KVal = KAmts[0];
6391     QVal = QAmts[0];
6392   }
6393 
6394   if (!ComparingWithAllZeros && !AllComparisonsWithNonZerosAreTautological) {
6395     if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::SUB, VT))
6396       return SDValue(); // FIXME: Could/should use `ISD::ADD`?
6397     assert(CompTargetNode.getValueType() == N.getValueType() &&
6398            "Expecting that the types on LHS and RHS of comparisons match.");
6399     N = DAG.getNode(ISD::SUB, DL, VT, N, CompTargetNode);
6400   }
6401 
6402   // (mul N, P)
6403   SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal);
6404   Created.push_back(Op0.getNode());
6405 
6406   // Rotate right only if any divisor was even. We avoid rotates for all-odd
6407   // divisors as a performance improvement, since rotating by 0 is a no-op.
6408   if (HadEvenDivisor) {
6409     // We need ROTR to do this.
6410     if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::ROTR, VT))
6411       return SDValue();
6412     // UREM: (rotr (mul N, P), K)
6413     Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal);
6414     Created.push_back(Op0.getNode());
6415   }
6416 
6417   // UREM: (setule/setugt (rotr (mul N, P), K), Q)
6418   SDValue NewCC =
6419       DAG.getSetCC(DL, SETCCVT, Op0, QVal,
6420                    ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
6421   if (!HadTautologicalInvertedLanes)
6422     return NewCC;
6423 
6424   // If any lanes previously compared always-false, the NewCC will give
6425   // always-true result for them, so we need to fixup those lanes.
6426   // Or the other way around for inequality predicate.
6427   assert(VT.isVector() && "Can/should only get here for vectors.");
6428   Created.push_back(NewCC.getNode());
6429 
6430   // x u% C1` is *always* less than C1. So given `x u% C1 == C2`,
6431   // if C2 is not less than C1, the comparison is always false.
6432   // But we have produced the comparison that will give the
6433   // opposive tautological answer. So these lanes would need to be fixed up.
6434   SDValue TautologicalInvertedChannels =
6435       DAG.getSetCC(DL, SETCCVT, D, CompTargetNode, ISD::SETULE);
6436   Created.push_back(TautologicalInvertedChannels.getNode());
6437 
6438   // NOTE: we avoid letting illegal types through even if we're before legalize
6439   // ops – legalization has a hard time producing good code for this.
6440   if (isOperationLegalOrCustom(ISD::VSELECT, SETCCVT)) {
6441     // If we have a vector select, let's replace the comparison results in the
6442     // affected lanes with the correct tautological result.
6443     SDValue Replacement = DAG.getBoolConstant(Cond == ISD::SETEQ ? false : true,
6444                                               DL, SETCCVT, SETCCVT);
6445     return DAG.getNode(ISD::VSELECT, DL, SETCCVT, TautologicalInvertedChannels,
6446                        Replacement, NewCC);
6447   }
6448 
6449   // Else, we can just invert the comparison result in the appropriate lanes.
6450   //
6451   // NOTE: see the note above VSELECT above.
6452   if (isOperationLegalOrCustom(ISD::XOR, SETCCVT))
6453     return DAG.getNode(ISD::XOR, DL, SETCCVT, NewCC,
6454                        TautologicalInvertedChannels);
6455 
6456   return SDValue(); // Don't know how to lower.
6457 }
6458 
6459 /// Given an ISD::SREM used only by an ISD::SETEQ or ISD::SETNE
6460 /// where the divisor is constant and the comparison target is zero,
6461 /// return a DAG expression that will generate the same comparison result
6462 /// using only multiplications, additions and shifts/rotations.
6463 /// Ref: "Hacker's Delight" 10-17.
buildSREMEqFold(EVT SETCCVT,SDValue REMNode,SDValue CompTargetNode,ISD::CondCode Cond,DAGCombinerInfo & DCI,const SDLoc & DL) const6464 SDValue TargetLowering::buildSREMEqFold(EVT SETCCVT, SDValue REMNode,
6465                                         SDValue CompTargetNode,
6466                                         ISD::CondCode Cond,
6467                                         DAGCombinerInfo &DCI,
6468                                         const SDLoc &DL) const {
6469   SmallVector<SDNode *, 7> Built;
6470   if (SDValue Folded = prepareSREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond,
6471                                          DCI, DL, Built)) {
6472     assert(Built.size() <= 7 && "Max size prediction failed.");
6473     for (SDNode *N : Built)
6474       DCI.AddToWorklist(N);
6475     return Folded;
6476   }
6477 
6478   return SDValue();
6479 }
6480 
/// Worker for buildSREMEqFold: attempt the actual
/// `(seteq/setne (srem N, D), 0)` rewrite. Returns an empty SDValue when the
/// fold is impossible or unprofitable; every newly created node is appended
/// to \p Created so the caller can add them to the combiner worklist.
SDValue
TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
                                  SDValue CompTargetNode, ISD::CondCode Cond,
                                  DAGCombinerInfo &DCI, const SDLoc &DL,
                                  SmallVectorImpl<SDNode *> &Created) const {
  // Fold:
  //   (seteq/ne (srem N, D), 0)
  // To:
  //   (setule/ugt (rotr (add (mul N, P), A), K), Q)
  //
  // - D must be constant, with D = D0 * 2^K where D0 is odd
  // - P is the multiplicative inverse of D0 modulo 2^W
  // - A = bitwiseand(floor((2^(W - 1) - 1) / D0), (-(2^k)))
  // - Q = floor((2 * A) / (2^K))
  // where W is the width of the common type of N and D.
  assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
         "Only applicable for (in)equality comparisons.");

  SelectionDAG &DAG = DCI.DAG;

  EVT VT = REMNode.getValueType();
  EVT SVT = VT.getScalarType();
  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout(), !DCI.isBeforeLegalize());
  EVT ShSVT = ShVT.getScalarType();

  // If we are after ops legalization, and MUL is unavailable, we can not
  // proceed.
  if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::MUL, VT))
    return SDValue();

  // TODO: Could support comparing with non-zero too.
  ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode);
  if (!CompTarget || !CompTarget->isZero())
    return SDValue();

  bool HadIntMinDivisor = false;
  bool HadOneDivisor = false;
  bool AllDivisorsAreOnes = true;
  bool HadEvenDivisor = false;
  bool NeedToApplyOffset = false;
  bool AllDivisorsArePowerOfTwo = true;
  SmallVector<SDValue, 16> PAmts, AAmts, KAmts, QAmts;

  // Per-lane worker: derives P, A, K, Q for one constant divisor and records
  // them in the vectors above. Returns false to abort the whole fold.
  auto BuildSREMPattern = [&](ConstantSDNode *C) {
    // Division by 0 is UB. Leave it to be constant-folded elsewhere.
    if (C->isZero())
      return false;

    // FIXME: we don't fold `rem %X, -C` to `rem %X, C` in DAGCombine.

    // WARNING: this fold is only valid for positive divisors!
    APInt D = C->getAPIntValue();
    if (D.isNegative())
      D.negate(); //  `rem %X, -C` is equivalent to `rem %X, C`

    HadIntMinDivisor |= D.isMinSignedValue();

    // If all divisors are ones, we will prefer to avoid the fold.
    HadOneDivisor |= D.isOne();
    AllDivisorsAreOnes &= D.isOne();

    // Decompose D into D0 * 2^K
    unsigned K = D.countTrailingZeros();
    assert((!D.isOne() || (K == 0)) && "For divisor '1' we won't rotate.");
    APInt D0 = D.lshr(K);

    if (!D.isMinSignedValue()) {
      // D is even if it has trailing zeros; unless it's INT_MIN, in which case
      // we don't care about this lane in this fold, we'll special-handle it.
      HadEvenDivisor |= (K != 0);
    }

    // D is a power-of-two if D0 is one. This includes INT_MIN.
    // If all divisors are power-of-two, we will prefer to avoid the fold.
    AllDivisorsArePowerOfTwo &= D0.isOne();

    // P = inv(D0, 2^W)
    // 2^W requires W + 1 bits, so we have to extend and then truncate.
    unsigned W = D.getBitWidth();
    APInt P = D0.zext(W + 1)
                  .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
                  .trunc(W);
    assert(!P.isZero() && "No multiplicative inverse!"); // unreachable
    assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");

    // A = floor((2^(W - 1) - 1) / D0) & -2^K
    APInt A = APInt::getSignedMaxValue(W).udiv(D0);
    A.clearLowBits(K);

    if (!D.isMinSignedValue()) {
      // If the divisor is INT_MIN we don't care about this lane in this fold;
      // it will be special-handled later.
      NeedToApplyOffset |= A != 0;
    }

    // Q = floor((2 * A) / (2^K))
    APInt Q = (2 * A).udiv(APInt::getOneBitSet(W, K));

    assert(APInt::getAllOnes(SVT.getSizeInBits()).ugt(A) &&
           "We are expecting that A is always less than all-ones for SVT");
    assert(APInt::getAllOnes(ShSVT.getSizeInBits()).ugt(K) &&
           "We are expecting that K is always less than all-ones for ShSVT");

    // If the divisor is 1 the result can be constant-folded. Likewise, we
    // don't care about INT_MIN lanes, those can be set to undef if appropriate.
    if (D.isOne()) {
      // Set P, A and K to bogus values so we can try to splat them.
      P = 0;
      A = -1;
      K = -1;

      // x ?% 1 == 0  <-->  true  <-->  x u<= -1
      Q = -1;
    }

    PAmts.push_back(DAG.getConstant(P, DL, SVT));
    AAmts.push_back(DAG.getConstant(A, DL, SVT));
    KAmts.push_back(
        DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT));
    QAmts.push_back(DAG.getConstant(Q, DL, SVT));
    return true;
  };

  SDValue N = REMNode.getOperand(0);
  SDValue D = REMNode.getOperand(1);

  // Collect the values from each element.
  if (!ISD::matchUnaryPredicate(D, BuildSREMPattern))
    return SDValue();

  // If this is a srem by a one, avoid the fold since it can be constant-folded.
  if (AllDivisorsAreOnes)
    return SDValue();

  // If this is a srem by powers-of-two (including INT_MIN), avoid the fold
  // since it can be best implemented as a bit test.
  if (AllDivisorsArePowerOfTwo)
    return SDValue();

  // Materialize the per-lane magic numbers in whatever form matches the
  // divisor node (build_vector, splat_vector, or scalar constant).
  SDValue PVal, AVal, KVal, QVal;
  if (D.getOpcode() == ISD::BUILD_VECTOR) {
    if (HadOneDivisor) {
      // Try to turn PAmts into a splat, since we don't care about the values
      // that are currently '0'. If we can't, just keep '0's.
      turnVectorIntoSplatVector(PAmts, isNullConstant);
      // Try to turn AAmts into a splat, since we don't care about the
      // values that are currently '-1'. If we can't, change them to '0's.
      turnVectorIntoSplatVector(AAmts, isAllOnesConstant,
                                DAG.getConstant(0, DL, SVT));
      // Try to turn KAmts into a splat, since we don't care about the values
      // that are currently '-1'. If we can't, change them to '0's.
      turnVectorIntoSplatVector(KAmts, isAllOnesConstant,
                                DAG.getConstant(0, DL, ShSVT));
    }

    PVal = DAG.getBuildVector(VT, DL, PAmts);
    AVal = DAG.getBuildVector(VT, DL, AAmts);
    KVal = DAG.getBuildVector(ShVT, DL, KAmts);
    QVal = DAG.getBuildVector(VT, DL, QAmts);
  } else if (D.getOpcode() == ISD::SPLAT_VECTOR) {
    assert(PAmts.size() == 1 && AAmts.size() == 1 && KAmts.size() == 1 &&
           QAmts.size() == 1 &&
           "Expected matchUnaryPredicate to return one element for scalable "
           "vectors");
    PVal = DAG.getSplatVector(VT, DL, PAmts[0]);
    AVal = DAG.getSplatVector(VT, DL, AAmts[0]);
    KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]);
    QVal = DAG.getSplatVector(VT, DL, QAmts[0]);
  } else {
    assert(isa<ConstantSDNode>(D) && "Expected a constant");
    PVal = PAmts[0];
    AVal = AAmts[0];
    KVal = KAmts[0];
    QVal = QAmts[0];
  }

  // (mul N, P)
  SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal);
  Created.push_back(Op0.getNode());

  if (NeedToApplyOffset) {
    // We need ADD to do this.
    if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::ADD, VT))
      return SDValue();

    // (add (mul N, P), A)
    Op0 = DAG.getNode(ISD::ADD, DL, VT, Op0, AVal);
    Created.push_back(Op0.getNode());
  }

  // Rotate right only if any divisor was even. We avoid rotates for all-odd
  // divisors as a performance improvement, since rotating by 0 is a no-op.
  if (HadEvenDivisor) {
    // We need ROTR to do this.
    if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::ROTR, VT))
      return SDValue();
    // SREM: (rotr (add (mul N, P), A), K)
    Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal);
    Created.push_back(Op0.getNode());
  }

  // SREM: (setule/setugt (rotr (add (mul N, P), A), K), Q)
  SDValue Fold =
      DAG.getSetCC(DL, SETCCVT, Op0, QVal,
                   ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));

  // If we didn't have lanes with INT_MIN divisor, then we're done.
  if (!HadIntMinDivisor)
    return Fold;

  // That fold is only valid for positive divisors. Which effectively means,
  // it is invalid for INT_MIN divisors. So if we have such a lane,
  // we must fix-up results for said lanes.
  assert(VT.isVector() && "Can/should only get here for vectors.");

  // NOTE: we avoid letting illegal types through even if we're before legalize
  // ops -- legalization has a hard time producing good code for the code that
  // follows.
  if (!isOperationLegalOrCustom(ISD::SETEQ, VT) ||
      !isOperationLegalOrCustom(ISD::AND, VT) ||
      !isOperationLegalOrCustom(Cond, VT) ||
      !isOperationLegalOrCustom(ISD::VSELECT, SETCCVT))
    return SDValue();

  Created.push_back(Fold.getNode());

  SDValue IntMin = DAG.getConstant(
      APInt::getSignedMinValue(SVT.getScalarSizeInBits()), DL, VT);
  SDValue IntMax = DAG.getConstant(
      APInt::getSignedMaxValue(SVT.getScalarSizeInBits()), DL, VT);
  SDValue Zero =
      DAG.getConstant(APInt::getZero(SVT.getScalarSizeInBits()), DL, VT);

  // Which lanes had INT_MIN divisors? Divisor is constant, so const-folded.
  SDValue DivisorIsIntMin = DAG.getSetCC(DL, SETCCVT, D, IntMin, ISD::SETEQ);
  Created.push_back(DivisorIsIntMin.getNode());

  // (N s% INT_MIN) ==/!= 0  <-->  (N & INT_MAX) ==/!= 0
  SDValue Masked = DAG.getNode(ISD::AND, DL, VT, N, IntMax);
  Created.push_back(Masked.getNode());
  SDValue MaskedIsZero = DAG.getSetCC(DL, SETCCVT, Masked, Zero, Cond);
  Created.push_back(MaskedIsZero.getNode());

  // To produce final result we need to blend 2 vectors: 'SetCC' and
  // 'MaskedIsZero'. If the divisor for channel was *NOT* INT_MIN, we pick
  // from 'Fold', else pick from 'MaskedIsZero'. Since 'DivisorIsIntMin' is
  // constant-folded, select can get lowered to a shuffle with constant mask.
  SDValue Blended = DAG.getNode(ISD::VSELECT, DL, SETCCVT, DivisorIsIntMin,
                                MaskedIsZero, Fold);

  return Blended;
}
6733 
6734 bool TargetLowering::
verifyReturnAddressArgumentIsConstant(SDValue Op,SelectionDAG & DAG) const6735 verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
6736   if (!isa<ConstantSDNode>(Op.getOperand(0))) {
6737     DAG.getContext()->emitError("argument to '__builtin_return_address' must "
6738                                 "be a constant integer");
6739     return true;
6740   }
6741 
6742   return false;
6743 }
6744 
getSqrtInputTest(SDValue Op,SelectionDAG & DAG,const DenormalMode & Mode) const6745 SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
6746                                          const DenormalMode &Mode) const {
6747   SDLoc DL(Op);
6748   EVT VT = Op.getValueType();
6749   EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
6750   SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
6751   // Testing it with denormal inputs to avoid wrong estimate.
6752   if (Mode.Input == DenormalMode::IEEE) {
6753     // This is specifically a check for the handling of denormal inputs,
6754     // not the result.
6755 
6756     // Test = fabs(X) < SmallestNormal
6757     const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
6758     APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
6759     SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
6760     SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
6761     return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
6762   }
6763   // Test = X == 0.0
6764   return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
6765 }
6766 
getNegatedExpression(SDValue Op,SelectionDAG & DAG,bool LegalOps,bool OptForSize,NegatibleCost & Cost,unsigned Depth) const6767 SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
6768                                              bool LegalOps, bool OptForSize,
6769                                              NegatibleCost &Cost,
6770                                              unsigned Depth) const {
6771   // fneg is removable even if it has multiple uses.
6772   if (Op.getOpcode() == ISD::FNEG) {
6773     Cost = NegatibleCost::Cheaper;
6774     return Op.getOperand(0);
6775   }
6776 
6777   // Don't recurse exponentially.
6778   if (Depth > SelectionDAG::MaxRecursionDepth)
6779     return SDValue();
6780 
6781   // Pre-increment recursion depth for use in recursive calls.
6782   ++Depth;
6783   const SDNodeFlags Flags = Op->getFlags();
6784   const TargetOptions &Options = DAG.getTarget().Options;
6785   EVT VT = Op.getValueType();
6786   unsigned Opcode = Op.getOpcode();
6787 
6788   // Don't allow anything with multiple uses unless we know it is free.
6789   if (!Op.hasOneUse() && Opcode != ISD::ConstantFP) {
6790     bool IsFreeExtend = Opcode == ISD::FP_EXTEND &&
6791                         isFPExtFree(VT, Op.getOperand(0).getValueType());
6792     if (!IsFreeExtend)
6793       return SDValue();
6794   }
6795 
6796   auto RemoveDeadNode = [&](SDValue N) {
6797     if (N && N.getNode()->use_empty())
6798       DAG.RemoveDeadNode(N.getNode());
6799   };
6800 
6801   SDLoc DL(Op);
6802 
6803   // Because getNegatedExpression can delete nodes we need a handle to keep
6804   // temporary nodes alive in case the recursion manages to create an identical
6805   // node.
6806   std::list<HandleSDNode> Handles;
6807 
6808   switch (Opcode) {
6809   case ISD::ConstantFP: {
6810     // Don't invert constant FP values after legalization unless the target says
6811     // the negated constant is legal.
6812     bool IsOpLegal =
6813         isOperationLegal(ISD::ConstantFP, VT) ||
6814         isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT,
6815                      OptForSize);
6816 
6817     if (LegalOps && !IsOpLegal)
6818       break;
6819 
6820     APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
6821     V.changeSign();
6822     SDValue CFP = DAG.getConstantFP(V, DL, VT);
6823 
6824     // If we already have the use of the negated floating constant, it is free
6825     // to negate it even it has multiple uses.
6826     if (!Op.hasOneUse() && CFP.use_empty())
6827       break;
6828     Cost = NegatibleCost::Neutral;
6829     return CFP;
6830   }
6831   case ISD::BUILD_VECTOR: {
6832     // Only permit BUILD_VECTOR of constants.
6833     if (llvm::any_of(Op->op_values(), [&](SDValue N) {
6834           return !N.isUndef() && !isa<ConstantFPSDNode>(N);
6835         }))
6836       break;
6837 
6838     bool IsOpLegal =
6839         (isOperationLegal(ISD::ConstantFP, VT) &&
6840          isOperationLegal(ISD::BUILD_VECTOR, VT)) ||
6841         llvm::all_of(Op->op_values(), [&](SDValue N) {
6842           return N.isUndef() ||
6843                  isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT,
6844                               OptForSize);
6845         });
6846 
6847     if (LegalOps && !IsOpLegal)
6848       break;
6849 
6850     SmallVector<SDValue, 4> Ops;
6851     for (SDValue C : Op->op_values()) {
6852       if (C.isUndef()) {
6853         Ops.push_back(C);
6854         continue;
6855       }
6856       APFloat V = cast<ConstantFPSDNode>(C)->getValueAPF();
6857       V.changeSign();
6858       Ops.push_back(DAG.getConstantFP(V, DL, C.getValueType()));
6859     }
6860     Cost = NegatibleCost::Neutral;
6861     return DAG.getBuildVector(VT, DL, Ops);
6862   }
6863   case ISD::FADD: {
6864     if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
6865       break;
6866 
6867     // After operation legalization, it might not be legal to create new FSUBs.
6868     if (LegalOps && !isOperationLegalOrCustom(ISD::FSUB, VT))
6869       break;
6870     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
6871 
6872     // fold (fneg (fadd X, Y)) -> (fsub (fneg X), Y)
6873     NegatibleCost CostX = NegatibleCost::Expensive;
6874     SDValue NegX =
6875         getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
6876     // Prevent this node from being deleted by the next call.
6877     if (NegX)
6878       Handles.emplace_back(NegX);
6879 
6880     // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X)
6881     NegatibleCost CostY = NegatibleCost::Expensive;
6882     SDValue NegY =
6883         getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
6884 
6885     // We're done with the handles.
6886     Handles.clear();
6887 
6888     // Negate the X if its cost is less or equal than Y.
6889     if (NegX && (CostX <= CostY)) {
6890       Cost = CostX;
6891       SDValue N = DAG.getNode(ISD::FSUB, DL, VT, NegX, Y, Flags);
6892       if (NegY != N)
6893         RemoveDeadNode(NegY);
6894       return N;
6895     }
6896 
6897     // Negate the Y if it is not expensive.
6898     if (NegY) {
6899       Cost = CostY;
6900       SDValue N = DAG.getNode(ISD::FSUB, DL, VT, NegY, X, Flags);
6901       if (NegX != N)
6902         RemoveDeadNode(NegX);
6903       return N;
6904     }
6905     break;
6906   }
6907   case ISD::FSUB: {
6908     // We can't turn -(A-B) into B-A when we honor signed zeros.
6909     if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
6910       break;
6911 
6912     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
6913     // fold (fneg (fsub 0, Y)) -> Y
6914     if (ConstantFPSDNode *C = isConstOrConstSplatFP(X, /*AllowUndefs*/ true))
6915       if (C->isZero()) {
6916         Cost = NegatibleCost::Cheaper;
6917         return Y;
6918       }
6919 
6920     // fold (fneg (fsub X, Y)) -> (fsub Y, X)
6921     Cost = NegatibleCost::Neutral;
6922     return DAG.getNode(ISD::FSUB, DL, VT, Y, X, Flags);
6923   }
6924   case ISD::FMUL:
6925   case ISD::FDIV: {
6926     SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
6927 
6928     // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
6929     NegatibleCost CostX = NegatibleCost::Expensive;
6930     SDValue NegX =
6931         getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
6932     // Prevent this node from being deleted by the next call.
6933     if (NegX)
6934       Handles.emplace_back(NegX);
6935 
6936     // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
6937     NegatibleCost CostY = NegatibleCost::Expensive;
6938     SDValue NegY =
6939         getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
6940 
6941     // We're done with the handles.
6942     Handles.clear();
6943 
6944     // Negate the X if its cost is less or equal than Y.
6945     if (NegX && (CostX <= CostY)) {
6946       Cost = CostX;
6947       SDValue N = DAG.getNode(Opcode, DL, VT, NegX, Y, Flags);
6948       if (NegY != N)
6949         RemoveDeadNode(NegY);
6950       return N;
6951     }
6952 
6953     // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
6954     if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
6955       if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
6956         break;
6957 
6958     // Negate the Y if it is not expensive.
6959     if (NegY) {
6960       Cost = CostY;
6961       SDValue N = DAG.getNode(Opcode, DL, VT, X, NegY, Flags);
6962       if (NegX != N)
6963         RemoveDeadNode(NegX);
6964       return N;
6965     }
6966     break;
6967   }
6968   case ISD::FMA:
6969   case ISD::FMAD: {
6970     if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
6971       break;
6972 
6973     SDValue X = Op.getOperand(0), Y = Op.getOperand(1), Z = Op.getOperand(2);
6974     NegatibleCost CostZ = NegatibleCost::Expensive;
6975     SDValue NegZ =
6976         getNegatedExpression(Z, DAG, LegalOps, OptForSize, CostZ, Depth);
6977     // Give up if fail to negate the Z.
6978     if (!NegZ)
6979       break;
6980 
6981     // Prevent this node from being deleted by the next two calls.
6982     Handles.emplace_back(NegZ);
6983 
6984     // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
6985     NegatibleCost CostX = NegatibleCost::Expensive;
6986     SDValue NegX =
6987         getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
6988     // Prevent this node from being deleted by the next call.
6989     if (NegX)
6990       Handles.emplace_back(NegX);
6991 
6992     // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
6993     NegatibleCost CostY = NegatibleCost::Expensive;
6994     SDValue NegY =
6995         getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
6996 
6997     // We're done with the handles.
6998     Handles.clear();
6999 
7000     // Negate the X if its cost is less or equal than Y.
7001     if (NegX && (CostX <= CostY)) {
7002       Cost = std::min(CostX, CostZ);
7003       SDValue N = DAG.getNode(Opcode, DL, VT, NegX, Y, NegZ, Flags);
7004       if (NegY != N)
7005         RemoveDeadNode(NegY);
7006       return N;
7007     }
7008 
7009     // Negate the Y if it is not expensive.
7010     if (NegY) {
7011       Cost = std::min(CostY, CostZ);
7012       SDValue N = DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags);
7013       if (NegX != N)
7014         RemoveDeadNode(NegX);
7015       return N;
7016     }
7017     break;
7018   }
7019 
7020   case ISD::FP_EXTEND:
7021   case ISD::FSIN:
7022     if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
7023                                             OptForSize, Cost, Depth))
7024       return DAG.getNode(Opcode, DL, VT, NegV);
7025     break;
7026   case ISD::FP_ROUND:
7027     if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
7028                                             OptForSize, Cost, Depth))
7029       return DAG.getNode(ISD::FP_ROUND, DL, VT, NegV, Op.getOperand(1));
7030     break;
7031   case ISD::SELECT:
7032   case ISD::VSELECT: {
7033     // fold (fneg (select C, LHS, RHS)) -> (select C, (fneg LHS), (fneg RHS))
7034     // iff at least one cost is cheaper and the other is neutral/cheaper
7035     SDValue LHS = Op.getOperand(1);
7036     NegatibleCost CostLHS = NegatibleCost::Expensive;
7037     SDValue NegLHS =
7038         getNegatedExpression(LHS, DAG, LegalOps, OptForSize, CostLHS, Depth);
7039     if (!NegLHS || CostLHS > NegatibleCost::Neutral) {
7040       RemoveDeadNode(NegLHS);
7041       break;
7042     }
7043 
7044     // Prevent this node from being deleted by the next call.
7045     Handles.emplace_back(NegLHS);
7046 
7047     SDValue RHS = Op.getOperand(2);
7048     NegatibleCost CostRHS = NegatibleCost::Expensive;
7049     SDValue NegRHS =
7050         getNegatedExpression(RHS, DAG, LegalOps, OptForSize, CostRHS, Depth);
7051 
7052     // We're done with the handles.
7053     Handles.clear();
7054 
7055     if (!NegRHS || CostRHS > NegatibleCost::Neutral ||
7056         (CostLHS != NegatibleCost::Cheaper &&
7057          CostRHS != NegatibleCost::Cheaper)) {
7058       RemoveDeadNode(NegLHS);
7059       RemoveDeadNode(NegRHS);
7060       break;
7061     }
7062 
7063     Cost = std::min(CostLHS, CostRHS);
7064     return DAG.getSelect(DL, VT, Op.getOperand(0), NegLHS, NegRHS);
7065   }
7066   }
7067 
7068   return SDValue();
7069 }
7070 
7071 //===----------------------------------------------------------------------===//
7072 // Legalization Utilities
7073 //===----------------------------------------------------------------------===//
7074 
// Expand a wide multiply (ISD::MUL / ISD::UMUL_LOHI / ISD::SMUL_LOHI) of type
// VT into multiplies on the half-width type HiLoVT.
//
// On success the parts of the product are appended to Result in little-endian
// order: two HiLoVT halves for ISD::MUL, four for the *MUL_LOHI opcodes (the
// halves of the low VT result followed by the halves of the high VT result).
// LL/LH/RL/RH may optionally supply the already-split halves of LHS and RHS;
// they must be either all set or all null.
bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, const SDLoc &dl,
                                    SDValue LHS, SDValue RHS,
                                    SmallVectorImpl<SDValue> &Result,
                                    EVT HiLoVT, SelectionDAG &DAG,
                                    MulExpansionKind Kind, SDValue LL,
                                    SDValue LH, SDValue RL, SDValue RH) const {
  assert(Opcode == ISD::MUL || Opcode == ISD::UMUL_LOHI ||
         Opcode == ISD::SMUL_LOHI);

  // Determine which half-width multiply flavors are usable. Always-expand
  // mode skips the legality checks.
  bool HasMULHS = (Kind == MulExpansionKind::Always) ||
                  isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
  bool HasMULHU = (Kind == MulExpansionKind::Always) ||
                  isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
  bool HasSMUL_LOHI = (Kind == MulExpansionKind::Always) ||
                      isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
  bool HasUMUL_LOHI = (Kind == MulExpansionKind::Always) ||
                      isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);

  if (!HasMULHU && !HasMULHS && !HasUMUL_LOHI && !HasSMUL_LOHI)
    return false;

  unsigned OuterBitSize = VT.getScalarSizeInBits();
  unsigned InnerBitSize = HiLoVT.getScalarSizeInBits();

  // LL, LH, RL, and RH must be either all NULL or all set to a value.
  assert((LL.getNode() && LH.getNode() && RL.getNode() && RH.getNode()) ||
         (!LL.getNode() && !LH.getNode() && !RL.getNode() && !RH.getNode()));

  SDVTList VTs = DAG.getVTList(HiLoVT, HiLoVT);
  // Emit a half-width multiply of L and R producing both product halves,
  // preferring the fused *MUL_LOHI node over a MUL + MULH pair.
  auto MakeMUL_LOHI = [&](SDValue L, SDValue R, SDValue &Lo, SDValue &Hi,
                          bool Signed) -> bool {
    if ((Signed && HasSMUL_LOHI) || (!Signed && HasUMUL_LOHI)) {
      Lo = DAG.getNode(Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI, dl, VTs, L, R);
      Hi = SDValue(Lo.getNode(), 1);
      return true;
    }
    if ((Signed && HasMULHS) || (!Signed && HasMULHU)) {
      Lo = DAG.getNode(ISD::MUL, dl, HiLoVT, L, R);
      Hi = DAG.getNode(Signed ? ISD::MULHS : ISD::MULHU, dl, HiLoVT, L, R);
      return true;
    }
    return false;
  };

  SDValue Lo, Hi;

  // Split the operands into low halves if the caller didn't provide them.
  if (!LL.getNode() && !RL.getNode() &&
      isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
    LL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LHS);
    RL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, RHS);
  }

  if (!LL.getNode())
    return false;

  APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize);
  if (DAG.MaskedValueIsZero(LHS, HighMask) &&
      DAG.MaskedValueIsZero(RHS, HighMask)) {
    // The inputs are both zero-extended. A single unsigned half-width
    // multiply yields the whole product; the upper VT result is zero.
    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
      Result.push_back(Lo);
      Result.push_back(Hi);
      if (Opcode != ISD::MUL) {
        SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
        Result.push_back(Zero);
        Result.push_back(Zero);
      }
      return true;
    }
  }

  if (!VT.isVector() && Opcode == ISD::MUL &&
      DAG.ComputeMaxSignificantBits(LHS) <= InnerBitSize &&
      DAG.ComputeMaxSignificantBits(RHS) <= InnerBitSize) {
    // The input values are both sign-extended.
    // TODO non-MUL case?
    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true)) {
      Result.push_back(Lo);
      Result.push_back(Hi);
      return true;
    }
  }

  unsigned ShiftAmount = OuterBitSize - InnerBitSize;
  SDValue Shift = DAG.getShiftAmountConstant(ShiftAmount, VT, dl);

  // Compute the high halves of the operands if the caller didn't supply them.
  if (!LH.getNode() && !RH.getNode() &&
      isOperationLegalOrCustom(ISD::SRL, VT) &&
      isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
    LH = DAG.getNode(ISD::SRL, dl, VT, LHS, Shift);
    LH = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LH);
    RH = DAG.getNode(ISD::SRL, dl, VT, RHS, Shift);
    RH = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, RH);
  }

  if (!LH.getNode())
    return false;

  // Low partial product: LL * RL.
  if (!MakeMUL_LOHI(LL, RL, Lo, Hi, false))
    return false;

  Result.push_back(Lo);

  if (Opcode == ISD::MUL) {
    // Only the low VT bits are needed: fold the cross terms into the upper
    // HiLoVT half, ignoring any carries out of it.
    RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH);
    LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL);
    Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH);
    Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH);
    Result.push_back(Hi);
    return true;
  }

  // Compute the full width result.
  auto Merge = [&](SDValue Lo, SDValue Hi) -> SDValue {
    Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
    Hi = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
    Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
    return DAG.getNode(ISD::OR, dl, VT, Lo, Hi);
  };

  SDValue Next = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
  // First cross term: LL * RH.
  if (!MakeMUL_LOHI(LL, RH, Lo, Hi, false))
    return false;

  // This is effectively the add part of a multiply-add of half-sized operands,
  // so it cannot overflow.
  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));

  // Second cross term: LH * RL.
  if (!MakeMUL_LOHI(LH, RL, Lo, Hi, false))
    return false;

  SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
  EVT BoolType = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

  // Adding the second cross term can produce a carry; track it with the
  // glue-based ADDC/ADDE pair when available, otherwise with ADDCARRY.
  bool UseGlue = (isOperationLegalOrCustom(ISD::ADDC, VT) &&
                  isOperationLegalOrCustom(ISD::ADDE, VT));
  if (UseGlue)
    Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next,
                       Merge(Lo, Hi));
  else
    Next = DAG.getNode(ISD::ADDCARRY, dl, DAG.getVTList(VT, BoolType), Next,
                       Merge(Lo, Hi), DAG.getConstant(0, dl, BoolType));

  SDValue Carry = Next.getValue(1);
  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);

  // High partial product: LH * RH (signed for SMUL_LOHI), plus the carry.
  if (!MakeMUL_LOHI(LH, RH, Lo, Hi, Opcode == ISD::SMUL_LOHI))
    return false;

  if (UseGlue)
    Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero,
                     Carry);
  else
    Hi = DAG.getNode(ISD::ADDCARRY, dl, DAG.getVTList(HiLoVT, BoolType), Hi,
                     Zero, Carry);

  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));

  if (Opcode == ISD::SMUL_LOHI) {
    // Correct the unsigned partial products for a signed full product:
    // subtract the zero-extended opposite low half whenever a high half is
    // negative.
    SDValue NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
                                  DAG.getNode(ISD::ZERO_EXTEND, dl, VT, RL));
    Next = DAG.getSelectCC(dl, LH, Zero, NextSub, Next, ISD::SETLT);

    NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
                          DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LL));
    Next = DAG.getSelectCC(dl, RH, Zero, NextSub, Next, ISD::SETLT);
  }

  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
  return true;
}
7249 
expandMUL(SDNode * N,SDValue & Lo,SDValue & Hi,EVT HiLoVT,SelectionDAG & DAG,MulExpansionKind Kind,SDValue LL,SDValue LH,SDValue RL,SDValue RH) const7250 bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
7251                                SelectionDAG &DAG, MulExpansionKind Kind,
7252                                SDValue LL, SDValue LH, SDValue RL,
7253                                SDValue RH) const {
7254   SmallVector<SDValue, 2> Result;
7255   bool Ok = expandMUL_LOHI(N->getOpcode(), N->getValueType(0), SDLoc(N),
7256                            N->getOperand(0), N->getOperand(1), Result, HiLoVT,
7257                            DAG, Kind, LL, LH, RL, RH);
7258   if (Ok) {
7259     assert(Result.size() == 2);
7260     Lo = Result[0];
7261     Hi = Result[1];
7262   }
7263   return Ok;
7264 }
7265 
7266 // Optimize unsigned division or remainder by constants for types twice as large
7267 // as a legal VT.
7268 //
7269 // If (1 << (BitWidth / 2)) % Constant == 1, then the remainder
7270 // can be computed
7271 // as:
7272 //   Sum += __builtin_uadd_overflow(Lo, High, &Sum);
7273 //   Remainder = Sum % Constant
7274 // This is based on "Remainder by Summing Digits" from Hacker's Delight.
7275 //
7276 // For division, we can compute the remainder using the algorithm described
7277 // above, subtract it from the dividend to get an exact multiple of Constant.
// Then multiply that exact multiple by the multiplicative inverse modulo
7279 // (1 << (BitWidth / 2)) to get the quotient.
7280 
7281 // If Constant is even, we can shift right the dividend and the divisor by the
7282 // number of trailing zeros in Constant before applying the remainder algorithm.
7283 // If we're after the quotient, we can subtract this value from the shifted
7284 // dividend and multiply by the multiplicative inverse of the shifted divisor.
7285 // If we want the remainder, we shift the value left by the number of trailing
7286 // zeros and add the bits that were shifted out of the dividend.
expandDIVREMByConstant(SDNode * N,SmallVectorImpl<SDValue> & Result,EVT HiLoVT,SelectionDAG & DAG,SDValue LL,SDValue LH) const7287 bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7288                                             SmallVectorImpl<SDValue> &Result,
7289                                             EVT HiLoVT, SelectionDAG &DAG,
7290                                             SDValue LL, SDValue LH) const {
7291   unsigned Opcode = N->getOpcode();
7292   EVT VT = N->getValueType(0);
7293 
7294   // TODO: Support signed division/remainder.
7295   if (Opcode == ISD::SREM || Opcode == ISD::SDIV || Opcode == ISD::SDIVREM)
7296     return false;
7297   assert(
7298       (Opcode == ISD::UREM || Opcode == ISD::UDIV || Opcode == ISD::UDIVREM) &&
7299       "Unexpected opcode");
7300 
7301   auto *CN = dyn_cast<ConstantSDNode>(N->getOperand(1));
7302   if (!CN)
7303     return false;
7304 
7305   APInt Divisor = CN->getAPIntValue();
7306   unsigned BitWidth = Divisor.getBitWidth();
7307   unsigned HBitWidth = BitWidth / 2;
7308   assert(VT.getScalarSizeInBits() == BitWidth &&
7309          HiLoVT.getScalarSizeInBits() == HBitWidth && "Unexpected VTs");
7310 
7311   // Divisor needs to less than (1 << HBitWidth).
7312   APInt HalfMaxPlus1 = APInt::getOneBitSet(BitWidth, HBitWidth);
7313   if (Divisor.uge(HalfMaxPlus1))
7314     return false;
7315 
7316   // We depend on the UREM by constant optimization in DAGCombiner that requires
7317   // high multiply.
7318   if (!isOperationLegalOrCustom(ISD::MULHU, HiLoVT) &&
7319       !isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT))
7320     return false;
7321 
7322   // Don't expand if optimizing for size.
7323   if (DAG.shouldOptForSize())
7324     return false;
7325 
7326   // Early out for 0 or 1 divisors.
7327   if (Divisor.ule(1))
7328     return false;
7329 
7330   // If the divisor is even, shift it until it becomes odd.
7331   unsigned TrailingZeros = 0;
7332   if (!Divisor[0]) {
7333     TrailingZeros = Divisor.countTrailingZeros();
7334     Divisor.lshrInPlace(TrailingZeros);
7335   }
7336 
7337   SDLoc dl(N);
7338   SDValue Sum;
7339   SDValue PartialRem;
7340 
7341   // If (1 << HBitWidth) % divisor == 1, we can add the two halves together and
7342   // then add in the carry.
7343   // TODO: If we can't split it in half, we might be able to split into 3 or
7344   // more pieces using a smaller bit width.
7345   if (HalfMaxPlus1.urem(Divisor).isOneValue()) {
7346     assert(!LL == !LH && "Expected both input halves or no input halves!");
7347     if (!LL) {
7348       LL = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, N->getOperand(0),
7349                        DAG.getIntPtrConstant(0, dl));
7350       LH = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, N->getOperand(0),
7351                        DAG.getIntPtrConstant(1, dl));
7352     }
7353 
7354     // Shift the input by the number of TrailingZeros in the divisor. The
7355     // shifted out bits will be added to the remainder later.
7356     if (TrailingZeros) {
7357       // Save the shifted off bits if we need the remainder.
7358       if (Opcode != ISD::UDIV) {
7359         APInt Mask = APInt::getLowBitsSet(HBitWidth, TrailingZeros);
7360         PartialRem = DAG.getNode(ISD::AND, dl, HiLoVT, LL,
7361                                  DAG.getConstant(Mask, dl, HiLoVT));
7362       }
7363 
7364       LL = DAG.getNode(
7365           ISD::OR, dl, HiLoVT,
7366           DAG.getNode(ISD::SRL, dl, HiLoVT, LL,
7367                       DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl)),
7368           DAG.getNode(ISD::SHL, dl, HiLoVT, LH,
7369                       DAG.getShiftAmountConstant(HBitWidth - TrailingZeros,
7370                                                  HiLoVT, dl)));
7371       LH = DAG.getNode(ISD::SRL, dl, HiLoVT, LH,
7372                        DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
7373     }
7374 
7375     // Use addcarry if we can, otherwise use a compare to detect overflow.
7376     EVT SetCCType =
7377         getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), HiLoVT);
7378     if (isOperationLegalOrCustom(ISD::ADDCARRY, HiLoVT)) {
7379       SDVTList VTList = DAG.getVTList(HiLoVT, SetCCType);
7380       Sum = DAG.getNode(ISD::UADDO, dl, VTList, LL, LH);
7381       Sum = DAG.getNode(ISD::ADDCARRY, dl, VTList, Sum,
7382                         DAG.getConstant(0, dl, HiLoVT), Sum.getValue(1));
7383     } else {
7384       Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, LL, LH);
7385       SDValue Carry = DAG.getSetCC(dl, SetCCType, Sum, LL, ISD::SETULT);
7386       // If the boolean for the target is 0 or 1, we can add the setcc result
7387       // directly.
7388       if (getBooleanContents(HiLoVT) ==
7389           TargetLoweringBase::ZeroOrOneBooleanContent)
7390         Carry = DAG.getZExtOrTrunc(Carry, dl, HiLoVT);
7391       else
7392         Carry = DAG.getSelect(dl, HiLoVT, Carry, DAG.getConstant(1, dl, HiLoVT),
7393                               DAG.getConstant(0, dl, HiLoVT));
7394       Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, Sum, Carry);
7395     }
7396   }
7397 
7398   // If we didn't find a sum, we can't do the expansion.
7399   if (!Sum)
7400     return false;
7401 
7402   // Perform a HiLoVT urem on the Sum using truncated divisor.
7403   SDValue RemL =
7404       DAG.getNode(ISD::UREM, dl, HiLoVT, Sum,
7405                   DAG.getConstant(Divisor.trunc(HBitWidth), dl, HiLoVT));
7406   SDValue RemH = DAG.getConstant(0, dl, HiLoVT);
7407 
7408   if (Opcode != ISD::UREM) {
7409     // Subtract the remainder from the shifted dividend.
7410     SDValue Dividend = DAG.getNode(ISD::BUILD_PAIR, dl, VT, LL, LH);
7411     SDValue Rem = DAG.getNode(ISD::BUILD_PAIR, dl, VT, RemL, RemH);
7412 
7413     Dividend = DAG.getNode(ISD::SUB, dl, VT, Dividend, Rem);
7414 
7415     // Multiply by the multiplicative inverse of the divisor modulo
7416     // (1 << BitWidth).
7417     APInt Mod = APInt::getSignedMinValue(BitWidth + 1);
7418     APInt MulFactor = Divisor.zext(BitWidth + 1);
7419     MulFactor = MulFactor.multiplicativeInverse(Mod);
7420     MulFactor = MulFactor.trunc(BitWidth);
7421 
7422     SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
7423                                    DAG.getConstant(MulFactor, dl, VT));
7424 
7425     // Split the quotient into low and high parts.
7426     SDValue QuotL = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
7427                                 DAG.getIntPtrConstant(0, dl));
7428     SDValue QuotH = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, HiLoVT, Quotient,
7429                                 DAG.getIntPtrConstant(1, dl));
7430     Result.push_back(QuotL);
7431     Result.push_back(QuotH);
7432   }
7433 
7434   if (Opcode != ISD::UDIV) {
7435     // If we shifted the input, shift the remainder left and add the bits we
7436     // shifted off the input.
7437     if (TrailingZeros) {
7438       APInt Mask = APInt::getLowBitsSet(HBitWidth, TrailingZeros);
7439       RemL = DAG.getNode(ISD::SHL, dl, HiLoVT, RemL,
7440                          DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
7441       RemL = DAG.getNode(ISD::ADD, dl, HiLoVT, RemL, PartialRem);
7442     }
7443     Result.push_back(RemL);
7444     Result.push_back(DAG.getConstant(0, dl, HiLoVT));
7445   }
7446 
7447   return true;
7448 }
7449 
7450 // Check that (every element of) Z is undef or not an exact multiple of BW.
isNonZeroModBitWidthOrUndef(SDValue Z,unsigned BW)7451 static bool isNonZeroModBitWidthOrUndef(SDValue Z, unsigned BW) {
7452   return ISD::matchUnaryPredicate(
7453       Z,
7454       [=](ConstantSDNode *C) { return !C || C->getAPIntValue().urem(BW) != 0; },
7455       true);
7456 }
7457 
// Expand ISD::VP_FSHL/VP_FSHR into VP shift/and/xor/sub/urem/or nodes,
// threading the mask and explicit vector length (EVL) operands through every
// generated node. Mirrors the non-VP expansion in expandFunnelShift.
static SDValue expandVPFunnelShift(SDNode *Node, SelectionDAG &DAG) {
  EVT VT = Node->getValueType(0);
  SDValue ShX, ShY;
  SDValue ShAmt, InvShAmt;
  SDValue X = Node->getOperand(0);
  SDValue Y = Node->getOperand(1);
  SDValue Z = Node->getOperand(2);
  SDValue Mask = Node->getOperand(3);
  SDValue VL = Node->getOperand(4);

  unsigned BW = VT.getScalarSizeInBits();
  bool IsFSHL = Node->getOpcode() == ISD::VP_FSHL;
  SDLoc DL(SDValue(Node, 0));

  EVT ShVT = Z.getValueType();
  if (isNonZeroModBitWidthOrUndef(Z, BW)) {
    // fshl: X << C | Y >> (BW - C)
    // fshr: X << (BW - C) | Y >> C
    // where C = Z % BW is not zero
    SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
    ShAmt = DAG.getNode(ISD::VP_UREM, DL, ShVT, Z, BitWidthC, Mask, VL);
    InvShAmt = DAG.getNode(ISD::VP_SUB, DL, ShVT, BitWidthC, ShAmt, Mask, VL);
    ShX = DAG.getNode(ISD::VP_SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt, Mask,
                      VL);
    ShY = DAG.getNode(ISD::VP_LSHR, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt, Mask,
                      VL);
  } else {
    // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
    // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
    // The extra shift-by-one keeps every shift amount within [0, BW - 1]
    // even when Z % BW == 0, so no individual shift is out of range.
    SDValue BitMask = DAG.getConstant(BW - 1, DL, ShVT);
    if (isPowerOf2_32(BW)) {
      // Z % BW -> Z & (BW - 1)
      ShAmt = DAG.getNode(ISD::VP_AND, DL, ShVT, Z, BitMask, Mask, VL);
      // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
      SDValue NotZ = DAG.getNode(ISD::VP_XOR, DL, ShVT, Z,
                                 DAG.getAllOnesConstant(DL, ShVT), Mask, VL);
      InvShAmt = DAG.getNode(ISD::VP_AND, DL, ShVT, NotZ, BitMask, Mask, VL);
    } else {
      SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
      ShAmt = DAG.getNode(ISD::VP_UREM, DL, ShVT, Z, BitWidthC, Mask, VL);
      InvShAmt = DAG.getNode(ISD::VP_SUB, DL, ShVT, BitMask, ShAmt, Mask, VL);
    }

    SDValue One = DAG.getConstant(1, DL, ShVT);
    if (IsFSHL) {
      ShX = DAG.getNode(ISD::VP_SHL, DL, VT, X, ShAmt, Mask, VL);
      SDValue ShY1 = DAG.getNode(ISD::VP_LSHR, DL, VT, Y, One, Mask, VL);
      ShY = DAG.getNode(ISD::VP_LSHR, DL, VT, ShY1, InvShAmt, Mask, VL);
    } else {
      SDValue ShX1 = DAG.getNode(ISD::VP_SHL, DL, VT, X, One, Mask, VL);
      ShX = DAG.getNode(ISD::VP_SHL, DL, VT, ShX1, InvShAmt, Mask, VL);
      ShY = DAG.getNode(ISD::VP_LSHR, DL, VT, Y, ShAmt, Mask, VL);
    }
  }
  return DAG.getNode(ISD::VP_OR, DL, VT, ShX, ShY, Mask, VL);
}
7514 
expandFunnelShift(SDNode * Node,SelectionDAG & DAG) const7515 SDValue TargetLowering::expandFunnelShift(SDNode *Node,
7516                                           SelectionDAG &DAG) const {
7517   if (Node->isVPOpcode())
7518     return expandVPFunnelShift(Node, DAG);
7519 
7520   EVT VT = Node->getValueType(0);
7521 
7522   if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SHL, VT) ||
7523                         !isOperationLegalOrCustom(ISD::SRL, VT) ||
7524                         !isOperationLegalOrCustom(ISD::SUB, VT) ||
7525                         !isOperationLegalOrCustomOrPromote(ISD::OR, VT)))
7526     return SDValue();
7527 
7528   SDValue X = Node->getOperand(0);
7529   SDValue Y = Node->getOperand(1);
7530   SDValue Z = Node->getOperand(2);
7531 
7532   unsigned BW = VT.getScalarSizeInBits();
7533   bool IsFSHL = Node->getOpcode() == ISD::FSHL;
7534   SDLoc DL(SDValue(Node, 0));
7535 
7536   EVT ShVT = Z.getValueType();
7537 
7538   // If a funnel shift in the other direction is more supported, use it.
7539   unsigned RevOpcode = IsFSHL ? ISD::FSHR : ISD::FSHL;
7540   if (!isOperationLegalOrCustom(Node->getOpcode(), VT) &&
7541       isOperationLegalOrCustom(RevOpcode, VT) && isPowerOf2_32(BW)) {
7542     if (isNonZeroModBitWidthOrUndef(Z, BW)) {
7543       // fshl X, Y, Z -> fshr X, Y, -Z
7544       // fshr X, Y, Z -> fshl X, Y, -Z
7545       SDValue Zero = DAG.getConstant(0, DL, ShVT);
7546       Z = DAG.getNode(ISD::SUB, DL, VT, Zero, Z);
7547     } else {
7548       // fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
7549       // fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
7550       SDValue One = DAG.getConstant(1, DL, ShVT);
7551       if (IsFSHL) {
7552         Y = DAG.getNode(RevOpcode, DL, VT, X, Y, One);
7553         X = DAG.getNode(ISD::SRL, DL, VT, X, One);
7554       } else {
7555         X = DAG.getNode(RevOpcode, DL, VT, X, Y, One);
7556         Y = DAG.getNode(ISD::SHL, DL, VT, Y, One);
7557       }
7558       Z = DAG.getNOT(DL, Z, ShVT);
7559     }
7560     return DAG.getNode(RevOpcode, DL, VT, X, Y, Z);
7561   }
7562 
7563   SDValue ShX, ShY;
7564   SDValue ShAmt, InvShAmt;
7565   if (isNonZeroModBitWidthOrUndef(Z, BW)) {
7566     // fshl: X << C | Y >> (BW - C)
7567     // fshr: X << (BW - C) | Y >> C
7568     // where C = Z % BW is not zero
7569     SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
7570     ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
7571     InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt);
7572     ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt);
7573     ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt);
7574   } else {
7575     // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
7576     // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
7577     SDValue Mask = DAG.getConstant(BW - 1, DL, ShVT);
7578     if (isPowerOf2_32(BW)) {
7579       // Z % BW -> Z & (BW - 1)
7580       ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask);
7581       // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
7582       InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask);
7583     } else {
7584       SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
7585       ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
7586       InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt);
7587     }
7588 
7589     SDValue One = DAG.getConstant(1, DL, ShVT);
7590     if (IsFSHL) {
7591       ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt);
7592       SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One);
7593       ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt);
7594     } else {
7595       SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One);
7596       ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt);
7597       ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt);
7598     }
7599   }
7600   return DAG.getNode(ISD::OR, DL, VT, ShX, ShY);
7601 }
7602 
7603 // TODO: Merge with expandFunnelShift.
expandROT(SDNode * Node,bool AllowVectorOps,SelectionDAG & DAG) const7604 SDValue TargetLowering::expandROT(SDNode *Node, bool AllowVectorOps,
7605                                   SelectionDAG &DAG) const {
7606   EVT VT = Node->getValueType(0);
7607   unsigned EltSizeInBits = VT.getScalarSizeInBits();
7608   bool IsLeft = Node->getOpcode() == ISD::ROTL;
7609   SDValue Op0 = Node->getOperand(0);
7610   SDValue Op1 = Node->getOperand(1);
7611   SDLoc DL(SDValue(Node, 0));
7612 
7613   EVT ShVT = Op1.getValueType();
7614   SDValue Zero = DAG.getConstant(0, DL, ShVT);
7615 
7616   // If a rotate in the other direction is more supported, use it.
7617   unsigned RevRot = IsLeft ? ISD::ROTR : ISD::ROTL;
7618   if (!isOperationLegalOrCustom(Node->getOpcode(), VT) &&
7619       isOperationLegalOrCustom(RevRot, VT) && isPowerOf2_32(EltSizeInBits)) {
7620     SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1);
7621     return DAG.getNode(RevRot, DL, VT, Op0, Sub);
7622   }
7623 
7624   if (!AllowVectorOps && VT.isVector() &&
7625       (!isOperationLegalOrCustom(ISD::SHL, VT) ||
7626        !isOperationLegalOrCustom(ISD::SRL, VT) ||
7627        !isOperationLegalOrCustom(ISD::SUB, VT) ||
7628        !isOperationLegalOrCustomOrPromote(ISD::OR, VT) ||
7629        !isOperationLegalOrCustomOrPromote(ISD::AND, VT)))
7630     return SDValue();
7631 
7632   unsigned ShOpc = IsLeft ? ISD::SHL : ISD::SRL;
7633   unsigned HsOpc = IsLeft ? ISD::SRL : ISD::SHL;
7634   SDValue BitWidthMinusOneC = DAG.getConstant(EltSizeInBits - 1, DL, ShVT);
7635   SDValue ShVal;
7636   SDValue HsVal;
7637   if (isPowerOf2_32(EltSizeInBits)) {
7638     // (rotl x, c) -> x << (c & (w - 1)) | x >> (-c & (w - 1))
7639     // (rotr x, c) -> x >> (c & (w - 1)) | x << (-c & (w - 1))
7640     SDValue NegOp1 = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1);
7641     SDValue ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Op1, BitWidthMinusOneC);
7642     ShVal = DAG.getNode(ShOpc, DL, VT, Op0, ShAmt);
7643     SDValue HsAmt = DAG.getNode(ISD::AND, DL, ShVT, NegOp1, BitWidthMinusOneC);
7644     HsVal = DAG.getNode(HsOpc, DL, VT, Op0, HsAmt);
7645   } else {
7646     // (rotl x, c) -> x << (c % w) | x >> 1 >> (w - 1 - (c % w))
7647     // (rotr x, c) -> x >> (c % w) | x << 1 << (w - 1 - (c % w))
7648     SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT);
7649     SDValue ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Op1, BitWidthC);
7650     ShVal = DAG.getNode(ShOpc, DL, VT, Op0, ShAmt);
7651     SDValue HsAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthMinusOneC, ShAmt);
7652     SDValue One = DAG.getConstant(1, DL, ShVT);
7653     HsVal =
7654         DAG.getNode(HsOpc, DL, VT, DAG.getNode(HsOpc, DL, VT, Op0, One), HsAmt);
7655   }
7656   return DAG.getNode(ISD::OR, DL, VT, ShVal, HsVal);
7657 }
7658 
// Expand ISD::SHL_PARTS/SRL_PARTS/SRA_PARTS: shift the double-wide value
// given by the {ShOpLo, ShOpHi} pair by ShAmt, returning the two result
// parts through Lo and Hi. Uses FSHL/FSHR for the bits that cross the part
// boundary, plus selects to handle shift amounts >= the part width.
void TargetLowering::expandShiftParts(SDNode *Node, SDValue &Lo, SDValue &Hi,
                                      SelectionDAG &DAG) const {
  assert(Node->getNumOperands() == 3 && "Not a double-shift!");
  EVT VT = Node->getValueType(0);
  unsigned VTBits = VT.getScalarSizeInBits();
  assert(isPowerOf2_32(VTBits) && "Power-of-two integer type expected");

  bool IsSHL = Node->getOpcode() == ISD::SHL_PARTS;
  bool IsSRA = Node->getOpcode() == ISD::SRA_PARTS;
  SDValue ShOpLo = Node->getOperand(0);
  SDValue ShOpHi = Node->getOperand(1);
  SDValue ShAmt = Node->getOperand(2);
  EVT ShAmtVT = ShAmt.getValueType();
  EVT ShAmtCCVT =
      getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ShAmtVT);
  SDLoc dl(Node);

  // ISD::FSHL and ISD::FSHR have defined overflow behavior but ISD::SHL and
  // ISD::SRA/L nodes haven't. Insert an AND to be safe, it's usually optimized
  // away during isel.
  SDValue SafeShAmt = DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
                                  DAG.getConstant(VTBits - 1, dl, ShAmtVT));
  // Tmp1 is the value used for the part that is fully shifted out when the
  // amount is >= VTBits: replicated sign bits for SRA, zero otherwise.
  SDValue Tmp1 = IsSRA ? DAG.getNode(ISD::SRA, dl, VT, ShOpHi,
                                     DAG.getConstant(VTBits - 1, dl, ShAmtVT))
                       : DAG.getConstant(0, dl, VT);

  // Tmp2: the part that receives bits funneled across the Lo/Hi boundary.
  // Tmp3: the part shifted entirely within one half.
  SDValue Tmp2, Tmp3;
  if (IsSHL) {
    Tmp2 = DAG.getNode(ISD::FSHL, dl, VT, ShOpHi, ShOpLo, ShAmt);
    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, SafeShAmt);
  } else {
    Tmp2 = DAG.getNode(ISD::FSHR, dl, VT, ShOpHi, ShOpLo, ShAmt);
    Tmp3 = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, dl, VT, ShOpHi, SafeShAmt);
  }

  // If the shift amount is larger or equal than the width of a part we don't
  // use the result from the FSHL/FSHR. Insert a test and select the appropriate
  // values for large shift amounts.
  SDValue AndNode = DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
                                DAG.getConstant(VTBits, dl, ShAmtVT));
  SDValue Cond = DAG.getSetCC(dl, ShAmtCCVT, AndNode,
                              DAG.getConstant(0, dl, ShAmtVT), ISD::SETNE);

  if (IsSHL) {
    Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2);
    Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3);
  } else {
    Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2);
    Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3);
  }
}
7710 
/// Expand an FP_TO_SINT node into integer bit manipulation. Currently only
/// handles f32 -> i64 and only for non-strict nodes; returns false (leaving
/// \p Result untouched) when the expansion does not apply.
bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result,
                                      SelectionDAG &DAG) const {
  // For strict FP nodes operand 0 is the chain; the FP source follows it.
  unsigned OpNo = Node->isStrictFPOpcode() ? 1 : 0;
  SDValue Src = Node->getOperand(OpNo);
  EVT SrcVT = Src.getValueType();
  EVT DstVT = Node->getValueType(0);
  SDLoc dl(SDValue(Node, 0));

  // FIXME: Only f32 to i64 conversions are supported.
  if (SrcVT != MVT::f32 || DstVT != MVT::i64)
    return false;

  if (Node->isStrictFPOpcode())
    // When a NaN is converted to an integer a trap is allowed. We can't
    // use this expansion here because it would eliminate that trap. Other
    // traps are also allowed and cannot be eliminated. See
    // IEEE 754-2008 sec 5.8.
    return false;

  // Expand f32 -> i64 conversion
  // This algorithm comes from compiler-rt's implementation of fixsfdi:
  // https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/fixsfdi.c
  unsigned SrcEltBits = SrcVT.getScalarSizeInBits();
  EVT IntVT = SrcVT.changeTypeToInteger();
  EVT IntShVT = getShiftAmountTy(IntVT, DAG.getDataLayout());

  // IEEE-754 single-precision field masks and positions.
  SDValue ExponentMask = DAG.getConstant(0x7F800000, dl, IntVT);
  SDValue ExponentLoBit = DAG.getConstant(23, dl, IntVT);
  SDValue Bias = DAG.getConstant(127, dl, IntVT);
  SDValue SignMask = DAG.getConstant(APInt::getSignMask(SrcEltBits), dl, IntVT);
  SDValue SignLowBit = DAG.getConstant(SrcEltBits - 1, dl, IntVT);
  SDValue MantissaMask = DAG.getConstant(0x007FFFFF, dl, IntVT);

  // Reinterpret the float as raw bits so the fields can be extracted.
  SDValue Bits = DAG.getNode(ISD::BITCAST, dl, IntVT, Src);

  // Unbiased exponent: ((Bits & ExponentMask) >> 23) - 127.
  SDValue ExponentBits = DAG.getNode(
      ISD::SRL, dl, IntVT, DAG.getNode(ISD::AND, dl, IntVT, Bits, ExponentMask),
      DAG.getZExtOrTrunc(ExponentLoBit, dl, IntShVT));
  SDValue Exponent = DAG.getNode(ISD::SUB, dl, IntVT, ExponentBits, Bias);

  // Arithmetic-shifting the isolated sign bit down produces all-zeros or
  // all-ones; sign-extend that to the destination width (0 or -1).
  SDValue Sign = DAG.getNode(ISD::SRA, dl, IntVT,
                             DAG.getNode(ISD::AND, dl, IntVT, Bits, SignMask),
                             DAG.getZExtOrTrunc(SignLowBit, dl, IntShVT));
  Sign = DAG.getSExtOrTrunc(Sign, dl, DstVT);

  // Significand with the implicit leading 1 made explicit (bit 23).
  SDValue R = DAG.getNode(ISD::OR, dl, IntVT,
                          DAG.getNode(ISD::AND, dl, IntVT, Bits, MantissaMask),
                          DAG.getConstant(0x00800000, dl, IntVT));

  R = DAG.getZExtOrTrunc(R, dl, DstVT);

  // Scale the significand by the exponent: shift left when the unbiased
  // exponent exceeds the mantissa width, otherwise shift right.
  R = DAG.getSelectCC(
      dl, Exponent, ExponentLoBit,
      DAG.getNode(ISD::SHL, dl, DstVT, R,
                  DAG.getZExtOrTrunc(
                      DAG.getNode(ISD::SUB, dl, IntVT, Exponent, ExponentLoBit),
                      dl, IntShVT)),
      DAG.getNode(ISD::SRL, dl, DstVT, R,
                  DAG.getZExtOrTrunc(
                      DAG.getNode(ISD::SUB, dl, IntVT, ExponentLoBit, Exponent),
                      dl, IntShVT)),
      ISD::SETGT);

  // Apply the sign: (R ^ Sign) - Sign negates R when Sign is -1 and is the
  // identity when Sign is 0.
  SDValue Ret = DAG.getNode(ISD::SUB, dl, DstVT,
                            DAG.getNode(ISD::XOR, dl, DstVT, R, Sign), Sign);

  // A negative unbiased exponent means |Src| < 1, which truncates to 0.
  Result = DAG.getSelectCC(dl, Exponent, DAG.getConstant(0, dl, IntVT),
                           DAG.getConstant(0, dl, DstVT), Ret, ISD::SETLT);
  return true;
}
7781 
/// Expand FP_TO_UINT (and STRICT_FP_TO_UINT) in terms of FP_TO_SINT.
/// On success writes the converted value to \p Result (and the output chain
/// to \p Chain for strict nodes) and returns true; returns false when the
/// target lacks the operations this expansion needs.
bool TargetLowering::expandFP_TO_UINT(SDNode *Node, SDValue &Result,
                                      SDValue &Chain,
                                      SelectionDAG &DAG) const {
  SDLoc dl(SDValue(Node, 0));
  // For strict FP nodes operand 0 is the chain; the FP source follows it.
  unsigned OpNo = Node->isStrictFPOpcode() ? 1 : 0;
  SDValue Src = Node->getOperand(OpNo);

  EVT SrcVT = Src.getValueType();
  EVT DstVT = Node->getValueType(0);
  EVT SetCCVT =
      getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
  EVT DstSetCCVT =
      getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), DstVT);

  // Only expand vector types if we have the appropriate vector bit operations.
  unsigned SIntOpcode = Node->isStrictFPOpcode() ? ISD::STRICT_FP_TO_SINT :
                                                   ISD::FP_TO_SINT;
  if (DstVT.isVector() && (!isOperationLegalOrCustom(SIntOpcode, DstVT) ||
                           !isOperationLegalOrCustomOrPromote(ISD::XOR, SrcVT)))
    return false;

  // If the maximum float value is smaller than the signed integer range,
  // the destination signmask can't be represented by the float, so we can
  // just use FP_TO_SINT directly.
  const fltSemantics &APFSem = DAG.EVTToAPFloatSemantics(SrcVT);
  APFloat APF(APFSem, APInt::getZero(SrcVT.getScalarSizeInBits()));
  APInt SignMask = APInt::getSignMask(DstVT.getScalarSizeInBits());
  if (APFloat::opOverflow &
      APF.convertFromAPInt(SignMask, false, APFloat::rmNearestTiesToEven)) {
    if (Node->isStrictFPOpcode()) {
      Result = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, { DstVT, MVT::Other },
                           { Node->getOperand(0), Src });
      Chain = Result.getValue(1);
    } else
      Result = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Src);
    return true;
  }

  // Don't expand it if there isn't cheap fsub instruction.
  if (!isOperationLegalOrCustom(
          Node->isStrictFPOpcode() ? ISD::STRICT_FSUB : ISD::FSUB, SrcVT))
    return false;

  // Cst holds the FP value of the destination's signmask (from the
  // convertFromAPInt above); Src < Cst means FP_TO_SINT alone suffices.
  SDValue Cst = DAG.getConstantFP(APF, dl, SrcVT);
  SDValue Sel;

  if (Node->isStrictFPOpcode()) {
    // Signaling compare so strict-FP exception semantics are preserved.
    Sel = DAG.getSetCC(dl, SetCCVT, Src, Cst, ISD::SETLT,
                       Node->getOperand(0), /*IsSignaling*/ true);
    Chain = Sel.getValue(1);
  } else {
    Sel = DAG.getSetCC(dl, SetCCVT, Src, Cst, ISD::SETLT);
  }

  bool Strict = Node->isStrictFPOpcode() ||
                shouldUseStrictFP_TO_INT(SrcVT, DstVT, /*IsSigned*/ false);

  if (Strict) {
    // Expand based on maximum range of FP_TO_SINT, if the value exceeds the
    // signmask then offset (the result of which should be fully representable).
    // Sel = Src < 0x8000000000000000
    // FltOfs = select Sel, 0, 0x8000000000000000
    // IntOfs = select Sel, 0, 0x8000000000000000
    // Result = fp_to_sint(Src - FltOfs) ^ IntOfs

    // TODO: Should any fast-math-flags be set for the FSUB?
    SDValue FltOfs = DAG.getSelect(dl, SrcVT, Sel,
                                   DAG.getConstantFP(0.0, dl, SrcVT), Cst);
    // Re-extend the i1-ish condition to the destination's setcc type before
    // using it to select the integer offset.
    Sel = DAG.getBoolExtOrTrunc(Sel, dl, DstSetCCVT, DstVT);
    SDValue IntOfs = DAG.getSelect(dl, DstVT, Sel,
                                   DAG.getConstant(0, dl, DstVT),
                                   DAG.getConstant(SignMask, dl, DstVT));
    SDValue SInt;
    if (Node->isStrictFPOpcode()) {
      // Thread the chain through FSUB and FP_TO_SINT to keep ordering.
      SDValue Val = DAG.getNode(ISD::STRICT_FSUB, dl, { SrcVT, MVT::Other },
                                { Chain, Src, FltOfs });
      SInt = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, { DstVT, MVT::Other },
                         { Val.getValue(1), Val });
      Chain = SInt.getValue(1);
    } else {
      SDValue Val = DAG.getNode(ISD::FSUB, dl, SrcVT, Src, FltOfs);
      SInt = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Val);
    }
    Result = DAG.getNode(ISD::XOR, dl, DstVT, SInt, IntOfs);
  } else {
    // Expand based on maximum range of FP_TO_SINT:
    // True = fp_to_sint(Src)
    // False = 0x8000000000000000 + fp_to_sint(Src - 0x8000000000000000)
    // Result = select (Src < 0x8000000000000000), True, False

    SDValue True = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Src);
    // TODO: Should any fast-math-flags be set for the FSUB?
    SDValue False = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT,
                                DAG.getNode(ISD::FSUB, dl, SrcVT, Src, Cst));
    // XOR with the signmask adds 0x80..0 without risking signed overflow.
    False = DAG.getNode(ISD::XOR, dl, DstVT, False,
                        DAG.getConstant(SignMask, dl, DstVT));
    Sel = DAG.getBoolExtOrTrunc(Sel, dl, DstSetCCVT, DstVT);
    Result = DAG.getSelect(dl, DstVT, Sel, True, False);
  }
  return true;
}
7883 
/// Expand UINT_TO_FP for i64 -> f64 using integer splitting and FP
/// arithmetic (the compiler-rt __floatundidf scheme). On success writes the
/// converted value to \p Result and returns true; \p Chain is never written
/// because strict nodes are rejected up front.
bool TargetLowering::expandUINT_TO_FP(SDNode *Node, SDValue &Result,
                                      SDValue &Chain,
                                      SelectionDAG &DAG) const {
  // This transform is not correct for converting 0 when rounding mode is set
  // to round toward negative infinity which will produce -0.0. So disable under
  // strictfp.
  if (Node->isStrictFPOpcode())
    return false;

  SDValue Src = Node->getOperand(0);
  EVT SrcVT = Src.getValueType();
  EVT DstVT = Node->getValueType(0);

  // Only the u64 -> f64 shape is handled by this expansion.
  if (SrcVT.getScalarType() != MVT::i64 || DstVT.getScalarType() != MVT::f64)
    return false;

  // Only expand vector types if we have the appropriate vector bit operations.
  if (SrcVT.isVector() && (!isOperationLegalOrCustom(ISD::SRL, SrcVT) ||
                           !isOperationLegalOrCustom(ISD::FADD, DstVT) ||
                           !isOperationLegalOrCustom(ISD::FSUB, DstVT) ||
                           !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) ||
                           !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT)))
    return false;

  SDLoc dl(SDValue(Node, 0));
  EVT ShiftVT = getShiftAmountTy(SrcVT, DAG.getDataLayout());

  // Implementation of unsigned i64 to f64 following the algorithm in
  // __floatundidf in compiler_rt.  This implementation performs rounding
  // correctly in all rounding modes with the exception of converting 0
  // when rounding toward negative infinity. In that case the fsub will produce
  // -0.0. This will be added to +0.0 and produce -0.0 which is incorrect.
  // 0x4330000000000000 is the bit pattern of the double 2^52; OR-ing a
  // 32-bit value into its mantissa yields exactly 2^52 + value.
  SDValue TwoP52 = DAG.getConstant(UINT64_C(0x4330000000000000), dl, SrcVT);
  // The double constant 2^84 + 2^52, subtracted once to cancel both
  // implicit offsets introduced below.
  SDValue TwoP84PlusTwoP52 = DAG.getConstantFP(
      BitsToDouble(UINT64_C(0x4530000000100000)), dl, DstVT);
  // 0x4530000000000000 is the bit pattern of the double 2^84.
  SDValue TwoP84 = DAG.getConstant(UINT64_C(0x4530000000000000), dl, SrcVT);
  SDValue LoMask = DAG.getConstant(UINT64_C(0x00000000FFFFFFFF), dl, SrcVT);
  SDValue HiShift = DAG.getConstant(32, dl, ShiftVT);

  // Split the u64 into 32-bit halves and embed each half in the mantissa of
  // a power-of-two double: LoFlt == 2^52 + Lo, HiFlt == 2^84 + Hi * 2^32.
  SDValue Lo = DAG.getNode(ISD::AND, dl, SrcVT, Src, LoMask);
  SDValue Hi = DAG.getNode(ISD::SRL, dl, SrcVT, Src, HiShift);
  SDValue LoOr = DAG.getNode(ISD::OR, dl, SrcVT, Lo, TwoP52);
  SDValue HiOr = DAG.getNode(ISD::OR, dl, SrcVT, Hi, TwoP84);
  SDValue LoFlt = DAG.getBitcast(DstVT, LoOr);
  SDValue HiFlt = DAG.getBitcast(DstVT, HiOr);
  // HiSub == Hi * 2^32 - 2^52 (exact); the final FADD reconstructs
  // Hi * 2^32 + Lo with a single correctly-rounded step.
  SDValue HiSub =
      DAG.getNode(ISD::FSUB, dl, DstVT, HiFlt, TwoP84PlusTwoP52);
  Result = DAG.getNode(ISD::FADD, dl, DstVT, LoFlt, HiSub);
  return true;
}
7934 
7935 SDValue
createSelectForFMINNUM_FMAXNUM(SDNode * Node,SelectionDAG & DAG) const7936 TargetLowering::createSelectForFMINNUM_FMAXNUM(SDNode *Node,
7937                                                SelectionDAG &DAG) const {
7938   unsigned Opcode = Node->getOpcode();
7939   assert((Opcode == ISD::FMINNUM || Opcode == ISD::FMAXNUM ||
7940           Opcode == ISD::STRICT_FMINNUM || Opcode == ISD::STRICT_FMAXNUM) &&
7941          "Wrong opcode");
7942 
7943   if (Node->getFlags().hasNoNaNs()) {
7944     ISD::CondCode Pred = Opcode == ISD::FMINNUM ? ISD::SETLT : ISD::SETGT;
7945     SDValue Op1 = Node->getOperand(0);
7946     SDValue Op2 = Node->getOperand(1);
7947     SDValue SelCC = DAG.getSelectCC(SDLoc(Node), Op1, Op2, Op1, Op2, Pred);
7948     // Copy FMF flags, but always set the no-signed-zeros flag
7949     // as this is implied by the FMINNUM/FMAXNUM semantics.
7950     SDNodeFlags Flags = Node->getFlags();
7951     Flags.setNoSignedZeros(true);
7952     SelCC->setFlags(Flags);
7953     return SelCC;
7954   }
7955 
7956   return SDValue();
7957 }
7958 
/// Expand FMINNUM/FMAXNUM by trying, in order: the IEEE variants (with sNaN
/// quieting), FMINIMUM/FMAXIMUM under no-NaNs, then a compare+select.
/// Returns an empty SDValue when no expansion applies.
SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,
                                              SelectionDAG &DAG) const {
  SDLoc dl(Node);
  unsigned NewOp = Node->getOpcode() == ISD::FMINNUM ?
    ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
  EVT VT = Node->getValueType(0);

  if (VT.isScalableVector())
    report_fatal_error(
        "Expanding fminnum/fmaxnum for scalable vectors is undefined.");

  // Prefer the IEEE-754-2008 variant when the target supports it.
  if (isOperationLegalOrCustom(NewOp, VT)) {
    SDValue Quiet0 = Node->getOperand(0);
    SDValue Quiet1 = Node->getOperand(1);

    if (!Node->getFlags().hasNoNaNs()) {
      // Insert canonicalizes if it's possible we need to quiet to get correct
      // sNaN behavior.
      if (!DAG.isKnownNeverSNaN(Quiet0)) {
        Quiet0 = DAG.getNode(ISD::FCANONICALIZE, dl, VT, Quiet0,
                             Node->getFlags());
      }
      if (!DAG.isKnownNeverSNaN(Quiet1)) {
        Quiet1 = DAG.getNode(ISD::FCANONICALIZE, dl, VT, Quiet1,
                             Node->getFlags());
      }
    }

    return DAG.getNode(NewOp, dl, VT, Quiet0, Quiet1, Node->getFlags());
  }

  // If the target has FMINIMUM/FMAXIMUM but not FMINNUM/FMAXNUM use that
  // instead if there are no NaNs.
  if (Node->getFlags().hasNoNaNs()) {
    unsigned IEEE2018Op =
        Node->getOpcode() == ISD::FMINNUM ? ISD::FMINIMUM : ISD::FMAXIMUM;
    if (isOperationLegalOrCustom(IEEE2018Op, VT)) {
      return DAG.getNode(IEEE2018Op, dl, VT, Node->getOperand(0),
                         Node->getOperand(1), Node->getFlags());
    }
  }

  // Last resort: a plain compare+select, valid only under no-NaNs.
  if (SDValue SelCC = createSelectForFMINNUM_FMAXNUM(Node, DAG))
    return SelCC;

  return SDValue();
}
8006 
expandIS_FPCLASS(EVT ResultVT,SDValue Op,unsigned Test,SDNodeFlags Flags,const SDLoc & DL,SelectionDAG & DAG) const8007 SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
8008                                          unsigned Test, SDNodeFlags Flags,
8009                                          const SDLoc &DL,
8010                                          SelectionDAG &DAG) const {
8011   EVT OperandVT = Op.getValueType();
8012   assert(OperandVT.isFloatingPoint());
8013 
8014   // Degenerated cases.
8015   if (Test == 0)
8016     return DAG.getBoolConstant(false, DL, ResultVT, OperandVT);
8017   if ((Test & fcAllFlags) == fcAllFlags)
8018     return DAG.getBoolConstant(true, DL, ResultVT, OperandVT);
8019 
8020   // PPC double double is a pair of doubles, of which the higher part determines
8021   // the value class.
8022   if (OperandVT == MVT::ppcf128) {
8023     Op = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::f64, Op,
8024                      DAG.getConstant(1, DL, MVT::i32));
8025     OperandVT = MVT::f64;
8026   }
8027 
8028   // Some checks may be represented as inversion of simpler check, for example
8029   // "inf|normal|subnormal|zero" => !"nan".
8030   bool IsInverted = false;
8031   if (unsigned InvertedCheck = getInvertedFPClassTest(Test)) {
8032     IsInverted = true;
8033     Test = InvertedCheck;
8034   }
8035 
8036   // Floating-point type properties.
8037   EVT ScalarFloatVT = OperandVT.getScalarType();
8038   const Type *FloatTy = ScalarFloatVT.getTypeForEVT(*DAG.getContext());
8039   const llvm::fltSemantics &Semantics = FloatTy->getFltSemantics();
8040   bool IsF80 = (ScalarFloatVT == MVT::f80);
8041 
8042   // Some checks can be implemented using float comparisons, if floating point
8043   // exceptions are ignored.
8044   if (Flags.hasNoFPExcept() &&
8045       isOperationLegalOrCustom(ISD::SETCC, OperandVT.getScalarType())) {
8046     if (Test == fcZero)
8047       return DAG.getSetCC(DL, ResultVT, Op,
8048                           DAG.getConstantFP(0.0, DL, OperandVT),
8049                           IsInverted ? ISD::SETUNE : ISD::SETOEQ);
8050     if (Test == fcNan)
8051       return DAG.getSetCC(DL, ResultVT, Op, Op,
8052                           IsInverted ? ISD::SETO : ISD::SETUO);
8053   }
8054 
8055   // In the general case use integer operations.
8056   unsigned BitSize = OperandVT.getScalarSizeInBits();
8057   EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), BitSize);
8058   if (OperandVT.isVector())
8059     IntVT = EVT::getVectorVT(*DAG.getContext(), IntVT,
8060                              OperandVT.getVectorElementCount());
8061   SDValue OpAsInt = DAG.getBitcast(IntVT, Op);
8062 
8063   // Various masks.
8064   APInt SignBit = APInt::getSignMask(BitSize);
8065   APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
8066   APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
8067   const unsigned ExplicitIntBitInF80 = 63;
8068   APInt ExpMask = Inf;
8069   if (IsF80)
8070     ExpMask.clearBit(ExplicitIntBitInF80);
8071   APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
8072   APInt QNaNBitMask =
8073       APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
8074   APInt InvertionMask = APInt::getAllOnesValue(ResultVT.getScalarSizeInBits());
8075 
8076   SDValue ValueMaskV = DAG.getConstant(ValueMask, DL, IntVT);
8077   SDValue SignBitV = DAG.getConstant(SignBit, DL, IntVT);
8078   SDValue ExpMaskV = DAG.getConstant(ExpMask, DL, IntVT);
8079   SDValue ZeroV = DAG.getConstant(0, DL, IntVT);
8080   SDValue InfV = DAG.getConstant(Inf, DL, IntVT);
8081   SDValue ResultInvertionMask = DAG.getConstant(InvertionMask, DL, ResultVT);
8082 
8083   SDValue Res;
8084   const auto appendResult = [&](SDValue PartialRes) {
8085     if (PartialRes) {
8086       if (Res)
8087         Res = DAG.getNode(ISD::OR, DL, ResultVT, Res, PartialRes);
8088       else
8089         Res = PartialRes;
8090     }
8091   };
8092 
8093   SDValue IntBitIsSetV; // Explicit integer bit in f80 mantissa is set.
8094   const auto getIntBitIsSet = [&]() -> SDValue {
8095     if (!IntBitIsSetV) {
8096       APInt IntBitMask(BitSize, 0);
8097       IntBitMask.setBit(ExplicitIntBitInF80);
8098       SDValue IntBitMaskV = DAG.getConstant(IntBitMask, DL, IntVT);
8099       SDValue IntBitV = DAG.getNode(ISD::AND, DL, IntVT, OpAsInt, IntBitMaskV);
8100       IntBitIsSetV = DAG.getSetCC(DL, ResultVT, IntBitV, ZeroV, ISD::SETNE);
8101     }
8102     return IntBitIsSetV;
8103   };
8104 
8105   // Split the value into sign bit and absolute value.
8106   SDValue AbsV = DAG.getNode(ISD::AND, DL, IntVT, OpAsInt, ValueMaskV);
8107   SDValue SignV = DAG.getSetCC(DL, ResultVT, OpAsInt,
8108                                DAG.getConstant(0.0, DL, IntVT), ISD::SETLT);
8109 
8110   // Tests that involve more than one class should be processed first.
8111   SDValue PartialRes;
8112 
8113   if (IsF80)
8114     ; // Detect finite numbers of f80 by checking individual classes because
8115       // they have different settings of the explicit integer bit.
8116   else if ((Test & fcFinite) == fcFinite) {
8117     // finite(V) ==> abs(V) < exp_mask
8118     PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, ExpMaskV, ISD::SETLT);
8119     Test &= ~fcFinite;
8120   } else if ((Test & fcFinite) == fcPosFinite) {
8121     // finite(V) && V > 0 ==> V < exp_mask
8122     PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, ExpMaskV, ISD::SETULT);
8123     Test &= ~fcPosFinite;
8124   } else if ((Test & fcFinite) == fcNegFinite) {
8125     // finite(V) && V < 0 ==> abs(V) < exp_mask && signbit == 1
8126     PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, ExpMaskV, ISD::SETLT);
8127     PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, SignV);
8128     Test &= ~fcNegFinite;
8129   }
8130   appendResult(PartialRes);
8131 
8132   // Check for individual classes.
8133 
8134   if (unsigned PartialCheck = Test & fcZero) {
8135     if (PartialCheck == fcPosZero)
8136       PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, ZeroV, ISD::SETEQ);
8137     else if (PartialCheck == fcZero)
8138       PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, ZeroV, ISD::SETEQ);
8139     else // ISD::fcNegZero
8140       PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, SignBitV, ISD::SETEQ);
8141     appendResult(PartialRes);
8142   }
8143 
8144   if (unsigned PartialCheck = Test & fcInf) {
8145     if (PartialCheck == fcPosInf)
8146       PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, InfV, ISD::SETEQ);
8147     else if (PartialCheck == fcInf)
8148       PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, InfV, ISD::SETEQ);
8149     else { // ISD::fcNegInf
8150       APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
8151       SDValue NegInfV = DAG.getConstant(NegInf, DL, IntVT);
8152       PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, NegInfV, ISD::SETEQ);
8153     }
8154     appendResult(PartialRes);
8155   }
8156 
8157   if (unsigned PartialCheck = Test & fcNan) {
8158     APInt InfWithQnanBit = Inf | QNaNBitMask;
8159     SDValue InfWithQnanBitV = DAG.getConstant(InfWithQnanBit, DL, IntVT);
8160     if (PartialCheck == fcNan) {
8161       // isnan(V) ==> abs(V) > int(inf)
8162       PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, InfV, ISD::SETGT);
8163       if (IsF80) {
8164         // Recognize unsupported values as NaNs for compatibility with glibc.
8165         // In them (exp(V)==0) == int_bit.
8166         SDValue ExpBits = DAG.getNode(ISD::AND, DL, IntVT, AbsV, ExpMaskV);
8167         SDValue ExpIsZero =
8168             DAG.getSetCC(DL, ResultVT, ExpBits, ZeroV, ISD::SETEQ);
8169         SDValue IsPseudo =
8170             DAG.getSetCC(DL, ResultVT, getIntBitIsSet(), ExpIsZero, ISD::SETEQ);
8171         PartialRes = DAG.getNode(ISD::OR, DL, ResultVT, PartialRes, IsPseudo);
8172       }
8173     } else if (PartialCheck == fcQNan) {
8174       // isquiet(V) ==> abs(V) >= (unsigned(Inf) | quiet_bit)
8175       PartialRes =
8176           DAG.getSetCC(DL, ResultVT, AbsV, InfWithQnanBitV, ISD::SETGE);
8177     } else { // ISD::fcSNan
8178       // issignaling(V) ==> abs(V) > unsigned(Inf) &&
8179       //                    abs(V) < (unsigned(Inf) | quiet_bit)
8180       SDValue IsNan = DAG.getSetCC(DL, ResultVT, AbsV, InfV, ISD::SETGT);
8181       SDValue IsNotQnan =
8182           DAG.getSetCC(DL, ResultVT, AbsV, InfWithQnanBitV, ISD::SETLT);
8183       PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, IsNan, IsNotQnan);
8184     }
8185     appendResult(PartialRes);
8186   }
8187 
8188   if (unsigned PartialCheck = Test & fcSubnormal) {
8189     // issubnormal(V) ==> unsigned(abs(V) - 1) < (all mantissa bits set)
8190     // issubnormal(V) && V>0 ==> unsigned(V - 1) < (all mantissa bits set)
8191     SDValue V = (PartialCheck == fcPosSubnormal) ? OpAsInt : AbsV;
8192     SDValue MantissaV = DAG.getConstant(AllOneMantissa, DL, IntVT);
8193     SDValue VMinusOneV =
8194         DAG.getNode(ISD::SUB, DL, IntVT, V, DAG.getConstant(1, DL, IntVT));
8195     PartialRes = DAG.getSetCC(DL, ResultVT, VMinusOneV, MantissaV, ISD::SETULT);
8196     if (PartialCheck == fcNegSubnormal)
8197       PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, SignV);
8198     appendResult(PartialRes);
8199   }
8200 
8201   if (unsigned PartialCheck = Test & fcNormal) {
8202     // isnormal(V) ==> (0 < exp < max_exp) ==> (unsigned(exp-1) < (max_exp-1))
8203     APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
8204     SDValue ExpLSBV = DAG.getConstant(ExpLSB, DL, IntVT);
8205     SDValue ExpMinus1 = DAG.getNode(ISD::SUB, DL, IntVT, AbsV, ExpLSBV);
8206     APInt ExpLimit = ExpMask - ExpLSB;
8207     SDValue ExpLimitV = DAG.getConstant(ExpLimit, DL, IntVT);
8208     PartialRes = DAG.getSetCC(DL, ResultVT, ExpMinus1, ExpLimitV, ISD::SETULT);
8209     if (PartialCheck == fcNegNormal)
8210       PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, SignV);
8211     else if (PartialCheck == fcPosNormal) {
8212       SDValue PosSignV =
8213           DAG.getNode(ISD::XOR, DL, ResultVT, SignV, ResultInvertionMask);
8214       PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, PosSignV);
8215     }
8216     if (IsF80)
8217       PartialRes =
8218           DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, getIntBitIsSet());
8219     appendResult(PartialRes);
8220   }
8221 
8222   if (!Res)
8223     return DAG.getConstant(IsInverted, DL, ResultVT);
8224   if (IsInverted)
8225     Res = DAG.getNode(ISD::XOR, DL, ResultVT, Res, ResultInvertionMask);
8226   return Res;
8227 }
8228 
8229 // Only expand vector types if we have the appropriate vector bit operations.
canExpandVectorCTPOP(const TargetLowering & TLI,EVT VT)8230 static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
8231   assert(VT.isVector() && "Expected vector type");
8232   unsigned Len = VT.getScalarSizeInBits();
8233   return TLI.isOperationLegalOrCustom(ISD::ADD, VT) &&
8234          TLI.isOperationLegalOrCustom(ISD::SUB, VT) &&
8235          TLI.isOperationLegalOrCustom(ISD::SRL, VT) &&
8236          (Len == 8 || TLI.isOperationLegalOrCustom(ISD::MUL, VT)) &&
8237          TLI.isOperationLegalOrCustomOrPromote(ISD::AND, VT);
8238 }
8239 
expandCTPOP(SDNode * Node,SelectionDAG & DAG) const8240 SDValue TargetLowering::expandCTPOP(SDNode *Node, SelectionDAG &DAG) const {
8241   SDLoc dl(Node);
8242   EVT VT = Node->getValueType(0);
8243   EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
8244   SDValue Op = Node->getOperand(0);
8245   unsigned Len = VT.getScalarSizeInBits();
8246   assert(VT.isInteger() && "CTPOP not implemented for this type.");
8247 
8248   // TODO: Add support for irregular type lengths.
8249   if (!(Len <= 128 && Len % 8 == 0))
8250     return SDValue();
8251 
8252   // Only expand vector types if we have the appropriate vector bit operations.
8253   if (VT.isVector() && !canExpandVectorCTPOP(*this, VT))
8254     return SDValue();
8255 
8256   // This is the "best" algorithm from
8257   // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
8258   SDValue Mask55 =
8259       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl, VT);
8260   SDValue Mask33 =
8261       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl, VT);
8262   SDValue Mask0F =
8263       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl, VT);
8264 
8265   // v = v - ((v >> 1) & 0x55555555...)
8266   Op = DAG.getNode(ISD::SUB, dl, VT, Op,
8267                    DAG.getNode(ISD::AND, dl, VT,
8268                                DAG.getNode(ISD::SRL, dl, VT, Op,
8269                                            DAG.getConstant(1, dl, ShVT)),
8270                                Mask55));
8271   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
8272   Op = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::AND, dl, VT, Op, Mask33),
8273                    DAG.getNode(ISD::AND, dl, VT,
8274                                DAG.getNode(ISD::SRL, dl, VT, Op,
8275                                            DAG.getConstant(2, dl, ShVT)),
8276                                Mask33));
8277   // v = (v + (v >> 4)) & 0x0F0F0F0F...
8278   Op = DAG.getNode(ISD::AND, dl, VT,
8279                    DAG.getNode(ISD::ADD, dl, VT, Op,
8280                                DAG.getNode(ISD::SRL, dl, VT, Op,
8281                                            DAG.getConstant(4, dl, ShVT))),
8282                    Mask0F);
8283 
8284   if (Len <= 8)
8285     return Op;
8286 
8287   // Avoid the multiply if we only have 2 bytes to add.
8288   // TODO: Only doing this for scalars because vectors weren't as obviously
8289   // improved.
8290   if (Len == 16 && !VT.isVector()) {
8291     // v = (v + (v >> 8)) & 0x00FF;
8292     return DAG.getNode(ISD::AND, dl, VT,
8293                      DAG.getNode(ISD::ADD, dl, VT, Op,
8294                                  DAG.getNode(ISD::SRL, dl, VT, Op,
8295                                              DAG.getConstant(8, dl, ShVT))),
8296                      DAG.getConstant(0xFF, dl, VT));
8297   }
8298 
8299   // v = (v * 0x01010101...) >> (Len - 8)
8300   SDValue Mask01 =
8301       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x01)), dl, VT);
8302   return DAG.getNode(ISD::SRL, dl, VT,
8303                      DAG.getNode(ISD::MUL, dl, VT, Op, Mask01),
8304                      DAG.getConstant(Len - 8, dl, ShVT));
8305 }
8306 
expandVPCTPOP(SDNode * Node,SelectionDAG & DAG) const8307 SDValue TargetLowering::expandVPCTPOP(SDNode *Node, SelectionDAG &DAG) const {
8308   SDLoc dl(Node);
8309   EVT VT = Node->getValueType(0);
8310   EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
8311   SDValue Op = Node->getOperand(0);
8312   SDValue Mask = Node->getOperand(1);
8313   SDValue VL = Node->getOperand(2);
8314   unsigned Len = VT.getScalarSizeInBits();
8315   assert(VT.isInteger() && "VP_CTPOP not implemented for this type.");
8316 
8317   // TODO: Add support for irregular type lengths.
8318   if (!(Len <= 128 && Len % 8 == 0))
8319     return SDValue();
8320 
8321   // This is same algorithm of expandCTPOP from
8322   // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
8323   SDValue Mask55 =
8324       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl, VT);
8325   SDValue Mask33 =
8326       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl, VT);
8327   SDValue Mask0F =
8328       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl, VT);
8329 
8330   SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5;
8331 
8332   // v = v - ((v >> 1) & 0x55555555...)
8333   Tmp1 = DAG.getNode(ISD::VP_AND, dl, VT,
8334                      DAG.getNode(ISD::VP_LSHR, dl, VT, Op,
8335                                  DAG.getConstant(1, dl, ShVT), Mask, VL),
8336                      Mask55, Mask, VL);
8337   Op = DAG.getNode(ISD::VP_SUB, dl, VT, Op, Tmp1, Mask, VL);
8338 
8339   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
8340   Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Op, Mask33, Mask, VL);
8341   Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT,
8342                      DAG.getNode(ISD::VP_LSHR, dl, VT, Op,
8343                                  DAG.getConstant(2, dl, ShVT), Mask, VL),
8344                      Mask33, Mask, VL);
8345   Op = DAG.getNode(ISD::VP_ADD, dl, VT, Tmp2, Tmp3, Mask, VL);
8346 
8347   // v = (v + (v >> 4)) & 0x0F0F0F0F...
8348   Tmp4 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(4, dl, ShVT),
8349                      Mask, VL),
8350   Tmp5 = DAG.getNode(ISD::VP_ADD, dl, VT, Op, Tmp4, Mask, VL);
8351   Op = DAG.getNode(ISD::VP_AND, dl, VT, Tmp5, Mask0F, Mask, VL);
8352 
8353   if (Len <= 8)
8354     return Op;
8355 
8356   // v = (v * 0x01010101...) >> (Len - 8)
8357   SDValue Mask01 =
8358       DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x01)), dl, VT);
8359   return DAG.getNode(ISD::VP_LSHR, dl, VT,
8360                      DAG.getNode(ISD::VP_MUL, dl, VT, Op, Mask01, Mask, VL),
8361                      DAG.getConstant(Len - 8, dl, ShVT), Mask, VL);
8362 }
8363 
/// Expand ISD::CTLZ / ISD::CTLZ_ZERO_UNDEF when no direct lowering exists.
/// Tries the sibling opcode first, then falls back to the classic
/// "smear the high bit right, popcount the inverse" expansion.
SDValue TargetLowering::expandCTLZ(SDNode *Node, SelectionDAG &DAG) const {
  SDLoc dl(Node);
  EVT VT = Node->getValueType(0);
  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
  SDValue Op = Node->getOperand(0);
  unsigned NumBitsPerElt = VT.getScalarSizeInBits();

  // If the non-ZERO_UNDEF version is supported we can use that instead.
  if (Node->getOpcode() == ISD::CTLZ_ZERO_UNDEF &&
      isOperationLegalOrCustom(ISD::CTLZ, VT))
    return DAG.getNode(ISD::CTLZ, dl, VT, Op);

  // If the ZERO_UNDEF version is supported use that and handle the zero case.
  if (isOperationLegalOrCustom(ISD::CTLZ_ZERO_UNDEF, VT)) {
    EVT SetCCVT =
        getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    SDValue CTLZ = DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, VT, Op);
    SDValue Zero = DAG.getConstant(0, dl, VT);
    SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ);
    // A zero input is undef for CTLZ_ZERO_UNDEF; select the defined result
    // (the element bit width) in that case.
    return DAG.getSelect(dl, VT, SrcIsZero,
                         DAG.getConstant(NumBitsPerElt, dl, VT), CTLZ);
  }

  // Only expand vector types if we have the appropriate vector bit operations.
  // This includes the operations needed to expand CTPOP if it isn't supported.
  if (VT.isVector() && (!isPowerOf2_32(NumBitsPerElt) ||
                        (!isOperationLegalOrCustom(ISD::CTPOP, VT) &&
                         !canExpandVectorCTPOP(*this, VT)) ||
                        !isOperationLegalOrCustom(ISD::SRL, VT) ||
                        !isOperationLegalOrCustomOrPromote(ISD::OR, VT)))
    return SDValue();

  // for now, we do this:
  // x = x | (x >> 1);
  // x = x | (x >> 2);
  // ...
  // x = x | (x >>16);
  // x = x | (x >>32); // for 64-bit input
  // return popcount(~x);
  //
  // Ref: "Hacker's Delight" by Henry Warren
  for (unsigned i = 0; (1U << i) < NumBitsPerElt; ++i) {
    // Shift amounts double each iteration: 1, 2, 4, ... up to half the width.
    SDValue Tmp = DAG.getConstant(1ULL << i, dl, ShVT);
    Op = DAG.getNode(ISD::OR, dl, VT, Op,
                     DAG.getNode(ISD::SRL, dl, VT, Op, Tmp));
  }
  // After the smear, ~x has exactly ctlz(x) set bits (the leading zeros).
  Op = DAG.getNOT(dl, Op, VT);
  return DAG.getNode(ISD::CTPOP, dl, VT, Op);
}
8413 
expandVPCTLZ(SDNode * Node,SelectionDAG & DAG) const8414 SDValue TargetLowering::expandVPCTLZ(SDNode *Node, SelectionDAG &DAG) const {
8415   SDLoc dl(Node);
8416   EVT VT = Node->getValueType(0);
8417   EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
8418   SDValue Op = Node->getOperand(0);
8419   SDValue Mask = Node->getOperand(1);
8420   SDValue VL = Node->getOperand(2);
8421   unsigned NumBitsPerElt = VT.getScalarSizeInBits();
8422 
8423   // do this:
8424   // x = x | (x >> 1);
8425   // x = x | (x >> 2);
8426   // ...
8427   // x = x | (x >>16);
8428   // x = x | (x >>32); // for 64-bit input
8429   // return popcount(~x);
8430   for (unsigned i = 0; (1U << i) < NumBitsPerElt; ++i) {
8431     SDValue Tmp = DAG.getConstant(1ULL << i, dl, ShVT);
8432     Op = DAG.getNode(ISD::VP_OR, dl, VT, Op,
8433                      DAG.getNode(ISD::VP_LSHR, dl, VT, Op, Tmp, Mask, VL), Mask,
8434                      VL);
8435   }
8436   Op = DAG.getNode(ISD::VP_XOR, dl, VT, Op, DAG.getConstant(-1, dl, VT), Mask,
8437                    VL);
8438   return DAG.getNode(ISD::VP_CTPOP, dl, VT, Op, Mask, VL);
8439 }
8440 
/// Lower CTTZ via a de Bruijn multiply plus a byte table loaded from the
/// constant pool. Only 32- and 64-bit widths are handled; returns an empty
/// SDValue for anything else so the caller can fall back.
SDValue TargetLowering::CTTZTableLookup(SDNode *Node, SelectionDAG &DAG,
                                        const SDLoc &DL, EVT VT, SDValue Op,
                                        unsigned BitWidth) const {
  if (BitWidth != 32 && BitWidth != 64)
    return SDValue();
  // De Bruijn constants: multiplying an isolated single bit by one of these
  // places a unique Log2(BitWidth)-bit pattern in the top bits.
  APInt DeBruijn = BitWidth == 32 ? APInt(32, 0x077CB531U)
                                  : APInt(64, 0x0218A392CD3D5DBFULL);
  const DataLayout &TD = DAG.getDataLayout();
  MachinePointerInfo PtrInfo =
      MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
  unsigned ShiftAmt = BitWidth - Log2_32(BitWidth);
  // (Op & -Op) isolates the lowest set bit; the multiply + shift turns its
  // position into a table index.
  SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Op);
  SDValue Lookup = DAG.getNode(
      ISD::SRL, DL, VT,
      DAG.getNode(ISD::MUL, DL, VT, DAG.getNode(ISD::AND, DL, VT, Op, Neg),
                  DAG.getConstant(DeBruijn, DL, VT)),
      DAG.getConstant(ShiftAmt, DL, VT));
  Lookup = DAG.getSExtOrTrunc(Lookup, DL, getPointerTy(TD));

  // Build the inverse mapping: Table[(DeBruijn << i) >> ShiftAmt] = i.
  SmallVector<uint8_t> Table(BitWidth, 0);
  for (unsigned i = 0; i < BitWidth; i++) {
    APInt Shl = DeBruijn.shl(i);
    APInt Lshr = Shl.lshr(ShiftAmt);
    Table[Lshr.getZExtValue()] = i;
  }

  // Create a ConstantArray in Constant Pool
  auto *CA = ConstantDataArray::get(*DAG.getContext(), Table);
  SDValue CPIdx = DAG.getConstantPool(CA, getPointerTy(TD),
                                      TD.getPrefTypeAlign(CA->getType()));
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, DL, VT, DAG.getEntryNode(),
                                   DAG.getMemBasePlusOffset(CPIdx, Lookup, DL),
                                   PtrInfo, MVT::i8);
  // ZERO_UNDEF: a zero input is undefined, so the raw table result suffices.
  if (Node->getOpcode() == ISD::CTTZ_ZERO_UNDEF)
    return ExtLoad;

  // Plain CTTZ must yield BitWidth for a zero input, not the table entry
  // that index 0 happens to hold.
  EVT SetCCVT =
      getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  SDValue SrcIsZero = DAG.getSetCC(DL, SetCCVT, Op, Zero, ISD::SETEQ);
  return DAG.getSelect(DL, VT, SrcIsZero,
                       DAG.getConstant(BitWidth, DL, VT), ExtLoad);
}
8484 
/// Expand ISD::CTTZ / ISD::CTTZ_ZERO_UNDEF when there is no direct lowering.
/// Prefers the sibling opcode, then a de Bruijn table lookup, and finally
/// the popcount/ctlz-based expansion of ~x & (x - 1).
SDValue TargetLowering::expandCTTZ(SDNode *Node, SelectionDAG &DAG) const {
  SDLoc dl(Node);
  EVT VT = Node->getValueType(0);
  SDValue Op = Node->getOperand(0);
  unsigned NumBitsPerElt = VT.getScalarSizeInBits();

  // If the non-ZERO_UNDEF version is supported we can use that instead.
  if (Node->getOpcode() == ISD::CTTZ_ZERO_UNDEF &&
      isOperationLegalOrCustom(ISD::CTTZ, VT))
    return DAG.getNode(ISD::CTTZ, dl, VT, Op);

  // If the ZERO_UNDEF version is supported use that and handle the zero case.
  if (isOperationLegalOrCustom(ISD::CTTZ_ZERO_UNDEF, VT)) {
    EVT SetCCVT =
        getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    SDValue CTTZ = DAG.getNode(ISD::CTTZ_ZERO_UNDEF, dl, VT, Op);
    SDValue Zero = DAG.getConstant(0, dl, VT);
    SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ);
    // A zero input is undef for CTTZ_ZERO_UNDEF; select the defined result
    // (the element bit width) in that case.
    return DAG.getSelect(dl, VT, SrcIsZero,
                         DAG.getConstant(NumBitsPerElt, dl, VT), CTTZ);
  }

  // Only expand vector types if we have the appropriate vector bit operations.
  // This includes the operations needed to expand CTPOP if it isn't supported.
  if (VT.isVector() && (!isPowerOf2_32(NumBitsPerElt) ||
                        (!isOperationLegalOrCustom(ISD::CTPOP, VT) &&
                         !isOperationLegalOrCustom(ISD::CTLZ, VT) &&
                         !canExpandVectorCTPOP(*this, VT)) ||
                        !isOperationLegalOrCustom(ISD::SUB, VT) ||
                        !isOperationLegalOrCustomOrPromote(ISD::AND, VT) ||
                        !isOperationLegalOrCustomOrPromote(ISD::XOR, VT)))
    return SDValue();

  // Emit Table Lookup if ISD::CTLZ and ISD::CTPOP are not legal.
  if (!VT.isVector() && isOperationExpand(ISD::CTPOP, VT) &&
      !isOperationLegal(ISD::CTLZ, VT))
    if (SDValue V = CTTZTableLookup(Node, DAG, dl, VT, Op, NumBitsPerElt))
      return V;

  // for now, we use: { return popcount(~x & (x - 1)); }
  // unless the target has ctlz but not ctpop, in which case we use:
  // { return 32 - nlz(~x & (x-1)); }
  // Ref: "Hacker's Delight" by Henry Warren
  // ~x & (x - 1) sets exactly the bits below the lowest set bit of x.
  SDValue Tmp = DAG.getNode(
      ISD::AND, dl, VT, DAG.getNOT(dl, Op, VT),
      DAG.getNode(ISD::SUB, dl, VT, Op, DAG.getConstant(1, dl, VT)));

  // If ISD::CTLZ is legal and CTPOP isn't, then do that instead.
  if (isOperationLegal(ISD::CTLZ, VT) && !isOperationLegal(ISD::CTPOP, VT)) {
    return DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(NumBitsPerElt, dl, VT),
                       DAG.getNode(ISD::CTLZ, dl, VT, Tmp));
  }

  return DAG.getNode(ISD::CTPOP, dl, VT, Tmp);
}
8540 
expandVPCTTZ(SDNode * Node,SelectionDAG & DAG) const8541 SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
8542   SDValue Op = Node->getOperand(0);
8543   SDValue Mask = Node->getOperand(1);
8544   SDValue VL = Node->getOperand(2);
8545   SDLoc dl(Node);
8546   EVT VT = Node->getValueType(0);
8547 
8548   // Same as the vector part of expandCTTZ, use: popcount(~x & (x - 1))
8549   SDValue Not = DAG.getNode(ISD::VP_XOR, dl, VT, Op,
8550                             DAG.getConstant(-1, dl, VT), Mask, VL);
8551   SDValue MinusOne = DAG.getNode(ISD::VP_SUB, dl, VT, Op,
8552                                  DAG.getConstant(1, dl, VT), Mask, VL);
8553   SDValue Tmp = DAG.getNode(ISD::VP_AND, dl, VT, Not, MinusOne, Mask, VL);
8554   return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
8555 }
8556 
/// Expand integer absolute value, or its negation when \p IsNegative is
/// set (abs(x) vs. 0 - abs(x)). Tries smax/umin/smin forms first, then the
/// shift/xor/sub expansion.
SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
                                  bool IsNegative) const {
  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
  SDValue Op = N->getOperand(0);

  // abs(x) -> smax(x,sub(0,x))
  if (!IsNegative && isOperationLegal(ISD::SUB, VT) &&
      isOperationLegal(ISD::SMAX, VT)) {
    SDValue Zero = DAG.getConstant(0, dl, VT);
    return DAG.getNode(ISD::SMAX, dl, VT, Op,
                       DAG.getNode(ISD::SUB, dl, VT, Zero, Op));
  }

  // abs(x) -> umin(x,sub(0,x))
  if (!IsNegative && isOperationLegal(ISD::SUB, VT) &&
      isOperationLegal(ISD::UMIN, VT)) {
    SDValue Zero = DAG.getConstant(0, dl, VT);
    // Freeze so both uses of Op see the same value if Op is undef/poison.
    Op = DAG.getFreeze(Op);
    return DAG.getNode(ISD::UMIN, dl, VT, Op,
                       DAG.getNode(ISD::SUB, dl, VT, Zero, Op));
  }

  // 0 - abs(x) -> smin(x, sub(0,x))
  if (IsNegative && isOperationLegal(ISD::SUB, VT) &&
      isOperationLegal(ISD::SMIN, VT)) {
    // Freeze so both uses of Op see the same value if Op is undef/poison.
    Op = DAG.getFreeze(Op);
    SDValue Zero = DAG.getConstant(0, dl, VT);
    return DAG.getNode(ISD::SMIN, dl, VT, Op,
                       DAG.getNode(ISD::SUB, dl, VT, Zero, Op));
  }

  // Only expand vector types if we have the appropriate vector operations.
  if (VT.isVector() &&
      (!isOperationLegalOrCustom(ISD::SRA, VT) ||
       (!IsNegative && !isOperationLegalOrCustom(ISD::ADD, VT)) ||
       (IsNegative && !isOperationLegalOrCustom(ISD::SUB, VT)) ||
       !isOperationLegalOrCustomOrPromote(ISD::XOR, VT)))
    return SDValue();

  Op = DAG.getFreeze(Op);
  // Arithmetic shift by width-1 replicates the sign bit across the element:
  // Shift is all-ones for negative Op, all-zeros otherwise.
  SDValue Shift =
      DAG.getNode(ISD::SRA, dl, VT, Op,
                  DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, ShVT));
  SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, Op, Shift);

  // abs(x) -> Y = sra (X, size(X)-1); sub (xor (X, Y), Y)
  if (!IsNegative)
    return DAG.getNode(ISD::SUB, dl, VT, Xor, Shift);

  // 0 - abs(x) -> Y = sra (X, size(X)-1); sub (Y, xor (X, Y))
  return DAG.getNode(ISD::SUB, dl, VT, Shift, Xor);
}
8611 
/// Expand ISD::BSWAP with a shift/mask/or sequence. Only simple scalar
/// element types i16/i32/i64 are handled; anything else returns an empty
/// SDValue so the caller can try another strategy.
SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  SDValue Op = N->getOperand(0);

  if (!VT.isSimple())
    return SDValue();

  EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
  SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5, Tmp6, Tmp7, Tmp8;
  switch (VT.getSimpleVT().getScalarType().SimpleTy) {
  default:
    return SDValue();
  case MVT::i16:
    // Use a rotate by 8. This can be further expanded if necessary.
    return DAG.getNode(ISD::ROTL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));
  case MVT::i32:
    // TmpN holds the input byte that ends up in byte position N-1 of the
    // result; the ORs then merge the four repositioned bytes.
    Tmp4 = DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(24, dl, SHVT));
    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Op,
                       DAG.getConstant(0xFF00, dl, VT));
    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(8, dl, SHVT));
    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));
    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(0xFF00, dl, VT));
    Tmp1 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(24, dl, SHVT));
    Tmp4 = DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp3);
    Tmp2 = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp1);
    return DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp2);
  case MVT::i64:
    // Same scheme with eight bytes: mask out each source byte, shift it to
    // its mirrored position, then OR everything together pairwise.
    Tmp8 = DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(56, dl, SHVT));
    Tmp7 = DAG.getNode(ISD::AND, dl, VT, Op,
                       DAG.getConstant(255ULL<<8, dl, VT));
    Tmp7 = DAG.getNode(ISD::SHL, dl, VT, Tmp7, DAG.getConstant(40, dl, SHVT));
    Tmp6 = DAG.getNode(ISD::AND, dl, VT, Op,
                       DAG.getConstant(255ULL<<16, dl, VT));
    Tmp6 = DAG.getNode(ISD::SHL, dl, VT, Tmp6, DAG.getConstant(24, dl, SHVT));
    Tmp5 = DAG.getNode(ISD::AND, dl, VT, Op,
                       DAG.getConstant(255ULL<<24, dl, VT));
    Tmp5 = DAG.getNode(ISD::SHL, dl, VT, Tmp5, DAG.getConstant(8, dl, SHVT));
    Tmp4 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));
    Tmp4 = DAG.getNode(ISD::AND, dl, VT, Tmp4,
                       DAG.getConstant(255ULL<<24, dl, VT));
    Tmp3 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(24, dl, SHVT));
    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp3,
                       DAG.getConstant(255ULL<<16, dl, VT));
    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(40, dl, SHVT));
    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2,
                       DAG.getConstant(255ULL<<8, dl, VT));
    Tmp1 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(56, dl, SHVT));
    Tmp8 = DAG.getNode(ISD::OR, dl, VT, Tmp8, Tmp7);
    Tmp6 = DAG.getNode(ISD::OR, dl, VT, Tmp6, Tmp5);
    Tmp4 = DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp3);
    Tmp2 = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp1);
    Tmp8 = DAG.getNode(ISD::OR, dl, VT, Tmp8, Tmp6);
    Tmp4 = DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp2);
    return DAG.getNode(ISD::OR, dl, VT, Tmp8, Tmp4);
  }
}
8669 
/// Expand ISD::VP_BSWAP with masked shift/mask/or VP operations, mirroring
/// expandBSWAP. i16 uses a shift pair (rather than the rotate used in the
/// non-VP expansion); i32/i64 reposition each byte individually. Other
/// element types return an empty SDValue.
SDValue TargetLowering::expandVPBSWAP(SDNode *N, SelectionDAG &DAG) const {
  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  SDValue Op = N->getOperand(0);
  SDValue Mask = N->getOperand(1);
  SDValue EVL = N->getOperand(2);

  if (!VT.isSimple())
    return SDValue();

  EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
  SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5, Tmp6, Tmp7, Tmp8;
  switch (VT.getSimpleVT().getScalarType().SimpleTy) {
  default:
    return SDValue();
  case MVT::i16:
    // Swap the two bytes: (x << 8) | (x >> 8).
    Tmp1 = DAG.getNode(ISD::VP_SHL, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
                       Mask, EVL);
    return DAG.getNode(ISD::VP_OR, dl, VT, Tmp1, Tmp2, Mask, EVL);
  case MVT::i32:
    // TmpN holds the input byte that ends up in byte position N-1 of the
    // result; the ORs merge the four repositioned bytes.
    Tmp4 = DAG.getNode(ISD::VP_SHL, dl, VT, Op, DAG.getConstant(24, dl, SHVT),
                       Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Op, DAG.getConstant(0xFF00, dl, VT),
                       Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(8, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
                       DAG.getConstant(0xFF00, dl, VT), Mask, EVL);
    Tmp1 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(24, dl, SHVT),
                       Mask, EVL);
    Tmp4 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp3, Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp1, Mask, EVL);
    return DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp2, Mask, EVL);
  case MVT::i64:
    // Same scheme with eight bytes: mask out each source byte, shift it to
    // its mirrored position, then OR everything together pairwise.
    Tmp8 = DAG.getNode(ISD::VP_SHL, dl, VT, Op, DAG.getConstant(56, dl, SHVT),
                       Mask, EVL);
    Tmp7 = DAG.getNode(ISD::VP_AND, dl, VT, Op,
                       DAG.getConstant(255ULL << 8, dl, VT), Mask, EVL);
    Tmp7 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp7, DAG.getConstant(40, dl, SHVT),
                       Mask, EVL);
    Tmp6 = DAG.getNode(ISD::VP_AND, dl, VT, Op,
                       DAG.getConstant(255ULL << 16, dl, VT), Mask, EVL);
    Tmp6 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp6, DAG.getConstant(24, dl, SHVT),
                       Mask, EVL);
    Tmp5 = DAG.getNode(ISD::VP_AND, dl, VT, Op,
                       DAG.getConstant(255ULL << 24, dl, VT), Mask, EVL);
    Tmp5 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp5, DAG.getConstant(8, dl, SHVT),
                       Mask, EVL);
    Tmp4 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
                       Mask, EVL);
    Tmp4 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp4,
                       DAG.getConstant(255ULL << 24, dl, VT), Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(24, dl, SHVT),
                       Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp3,
                       DAG.getConstant(255ULL << 16, dl, VT), Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(40, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
                       DAG.getConstant(255ULL << 8, dl, VT), Mask, EVL);
    Tmp1 = DAG.getNode(ISD::VP_LSHR, dl, VT, Op, DAG.getConstant(56, dl, SHVT),
                       Mask, EVL);
    Tmp8 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp8, Tmp7, Mask, EVL);
    Tmp6 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp6, Tmp5, Mask, EVL);
    Tmp4 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp3, Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp1, Mask, EVL);
    Tmp8 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp8, Tmp6, Mask, EVL);
    Tmp4 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp2, Mask, EVL);
    return DAG.getNode(ISD::VP_OR, dl, VT, Tmp8, Tmp4, Mask, EVL);
  }
}
8745 
/// Expand ISD::BITREVERSE. For byte-multiple power-of-two widths: BSWAP,
/// then swap nibbles, bit pairs, and single bits with mask/shift/or.
/// For all other widths: fall back to moving each bit individually.
SDValue TargetLowering::expandBITREVERSE(SDNode *N, SelectionDAG &DAG) const {
  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  SDValue Op = N->getOperand(0);
  EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
  unsigned Sz = VT.getScalarSizeInBits();

  SDValue Tmp, Tmp2, Tmp3;

  // If we can, perform BSWAP first and then the mask+swap the i4, then i2
  // and finally the i1 pairs.
  // TODO: We can easily support i4/i2 legal types if any target ever does.
  if (Sz >= 8 && isPowerOf2_32(Sz)) {
    // Create the masks - repeating the pattern every byte.
    APInt Mask4 = APInt::getSplat(Sz, APInt(8, 0x0F));
    APInt Mask2 = APInt::getSplat(Sz, APInt(8, 0x33));
    APInt Mask1 = APInt::getSplat(Sz, APInt(8, 0x55));

    // BSWAP if the type is wider than a single byte.
    Tmp = (Sz > 8 ? DAG.getNode(ISD::BSWAP, dl, VT, Op) : Op);

    // swap i4: ((V >> 4) & 0x0F) | ((V & 0x0F) << 4)
    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp, DAG.getConstant(4, dl, SHVT));
    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Mask4, dl, VT));
    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(Mask4, dl, VT));
    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(4, dl, SHVT));
    Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);

    // swap i2: ((V >> 2) & 0x33) | ((V & 0x33) << 2)
    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp, DAG.getConstant(2, dl, SHVT));
    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Mask2, dl, VT));
    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(Mask2, dl, VT));
    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(2, dl, SHVT));
    Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);

    // swap i1: ((V >> 1) & 0x55) | ((V & 0x55) << 1)
    Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp, DAG.getConstant(1, dl, SHVT));
    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Mask1, dl, VT));
    Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(Mask1, dl, VT));
    Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(1, dl, SHVT));
    Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
    return Tmp;
  }

  // Fallback for irregular widths: move each bit I to its mirrored position
  // J = Sz-1-I, isolate it with a one-bit mask, and OR it into the result.
  Tmp = DAG.getConstant(0, dl, VT);
  for (unsigned I = 0, J = Sz-1; I < Sz; ++I, --J) {
    if (I < J)
      Tmp2 =
          DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(J - I, dl, SHVT));
    else
      Tmp2 =
          DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(I - J, dl, SHVT));

    APInt Shift(Sz, 1);
    Shift <<= J;
    Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Shift, dl, VT));
    Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp, Tmp2);
  }

  return Tmp;
}
8807 
/// Expand ISD::VP_BITREVERSE like expandBITREVERSE (VP_BSWAP then masked
/// nibble/pair/bit swaps), but only for byte-multiple power-of-two widths;
/// unlike the non-VP version there is no per-bit fallback, so other widths
/// return an empty SDValue.
SDValue TargetLowering::expandVPBITREVERSE(SDNode *N, SelectionDAG &DAG) const {
  assert(N->getOpcode() == ISD::VP_BITREVERSE);

  SDLoc dl(N);
  EVT VT = N->getValueType(0);
  SDValue Op = N->getOperand(0);
  SDValue Mask = N->getOperand(1);
  SDValue EVL = N->getOperand(2);
  EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
  unsigned Sz = VT.getScalarSizeInBits();

  SDValue Tmp, Tmp2, Tmp3;

  // If we can, perform BSWAP first and then the mask+swap the i4, then i2
  // and finally the i1 pairs.
  // TODO: We can easily support i4/i2 legal types if any target ever does.
  if (Sz >= 8 && isPowerOf2_32(Sz)) {
    // Create the masks - repeating the pattern every byte.
    APInt Mask4 = APInt::getSplat(Sz, APInt(8, 0x0F));
    APInt Mask2 = APInt::getSplat(Sz, APInt(8, 0x33));
    APInt Mask1 = APInt::getSplat(Sz, APInt(8, 0x55));

    // BSWAP if the type is wider than a single byte.
    Tmp = (Sz > 8 ? DAG.getNode(ISD::VP_BSWAP, dl, VT, Op, Mask, EVL) : Op);

    // swap i4: ((V >> 4) & 0x0F) | ((V & 0x0F) << 4)
    Tmp2 = DAG.getNode(ISD::VP_LSHR, dl, VT, Tmp, DAG.getConstant(4, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
                       DAG.getConstant(Mask4, dl, VT), Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp, DAG.getConstant(Mask4, dl, VT),
                       Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(4, dl, SHVT),
                       Mask, EVL);
    Tmp = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp3, Mask, EVL);

    // swap i2: ((V >> 2) & 0x33) | ((V & 0x33) << 2)
    Tmp2 = DAG.getNode(ISD::VP_LSHR, dl, VT, Tmp, DAG.getConstant(2, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
                       DAG.getConstant(Mask2, dl, VT), Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp, DAG.getConstant(Mask2, dl, VT),
                       Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(2, dl, SHVT),
                       Mask, EVL);
    Tmp = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp3, Mask, EVL);

    // swap i1: ((V >> 1) & 0x55) | ((V & 0x55) << 1)
    Tmp2 = DAG.getNode(ISD::VP_LSHR, dl, VT, Tmp, DAG.getConstant(1, dl, SHVT),
                       Mask, EVL);
    Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
                       DAG.getConstant(Mask1, dl, VT), Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp, DAG.getConstant(Mask1, dl, VT),
                       Mask, EVL);
    Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(1, dl, SHVT),
                       Mask, EVL);
    Tmp = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp3, Mask, EVL);
    return Tmp;
  }
  return SDValue();
}
8869 
/// Turn a vector load into a sequence of scalar element loads, or — for
/// non-byte-sized elements — a single integer load followed by shift/mask
/// extraction of each element. Returns {built vector value, new chain}.
/// Scalable vectors cannot be scalarized and abort via report_fatal_error.
std::pair<SDValue, SDValue>
TargetLowering::scalarizeVectorLoad(LoadSDNode *LD,
                                    SelectionDAG &DAG) const {
  SDLoc SL(LD);
  SDValue Chain = LD->getChain();
  SDValue BasePTR = LD->getBasePtr();
  EVT SrcVT = LD->getMemoryVT();
  EVT DstVT = LD->getValueType(0);
  ISD::LoadExtType ExtType = LD->getExtensionType();

  if (SrcVT.isScalableVector())
    report_fatal_error("Cannot scalarize scalable vector loads");

  unsigned NumElem = SrcVT.getVectorNumElements();

  EVT SrcEltVT = SrcVT.getScalarType();
  EVT DstEltVT = DstVT.getScalarType();

  // A vector must always be stored in memory as-is, i.e. without any padding
  // between the elements, since various code depend on it, e.g. in the
  // handling of a bitcast of a vector type to int, which may be done with a
  // vector store followed by an integer load. A vector that does not have
  // elements that are byte-sized must therefore be stored as an integer
  // built out of the extracted vector elements.
  if (!SrcEltVT.isByteSized()) {
    unsigned NumLoadBits = SrcVT.getStoreSizeInBits();
    EVT LoadVT = EVT::getIntegerVT(*DAG.getContext(), NumLoadBits);

    unsigned NumSrcBits = SrcVT.getSizeInBits();
    EVT SrcIntVT = EVT::getIntegerVT(*DAG.getContext(), NumSrcBits);

    unsigned SrcEltBits = SrcEltVT.getSizeInBits();
    SDValue SrcEltBitMask = DAG.getConstant(
        APInt::getLowBitsSet(NumLoadBits, SrcEltBits), SL, LoadVT);

    // Load the whole vector and avoid masking off the top bits as it makes
    // the codegen worse.
    SDValue Load =
        DAG.getExtLoad(ISD::EXTLOAD, SL, LoadVT, Chain, BasePTR,
                       LD->getPointerInfo(), SrcIntVT, LD->getOriginalAlign(),
                       LD->getMemOperand()->getFlags(), LD->getAAInfo());

    // Extract element Idx by shifting its bits to the bottom and masking.
    SmallVector<SDValue, 8> Vals;
    for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
      // Big-endian layouts place element 0 in the most significant bits,
      // so the shift index is mirrored.
      unsigned ShiftIntoIdx =
          (DAG.getDataLayout().isBigEndian() ? (NumElem - 1) - Idx : Idx);
      SDValue ShiftAmount =
          DAG.getShiftAmountConstant(ShiftIntoIdx * SrcEltVT.getSizeInBits(),
                                     LoadVT, SL, /*LegalTypes=*/false);
      SDValue ShiftedElt = DAG.getNode(ISD::SRL, SL, LoadVT, Load, ShiftAmount);
      SDValue Elt =
          DAG.getNode(ISD::AND, SL, LoadVT, ShiftedElt, SrcEltBitMask);
      SDValue Scalar = DAG.getNode(ISD::TRUNCATE, SL, SrcEltVT, Elt);

      // Apply the load's extension (sext/zext/any) to reach the result
      // element type.
      if (ExtType != ISD::NON_EXTLOAD) {
        unsigned ExtendOp = ISD::getExtForLoadExtType(false, ExtType);
        Scalar = DAG.getNode(ExtendOp, SL, DstEltVT, Scalar);
      }

      Vals.push_back(Scalar);
    }

    SDValue Value = DAG.getBuildVector(DstVT, SL, Vals);
    return std::make_pair(Value, Load.getValue(1));
  }

  // Byte-sized elements: emit one scalar (ext)load per element at
  // consecutive byte offsets.
  unsigned Stride = SrcEltVT.getSizeInBits() / 8;
  assert(SrcEltVT.isByteSized());

  SmallVector<SDValue, 8> Vals;
  SmallVector<SDValue, 8> LoadChains;

  for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
    SDValue ScalarLoad =
        DAG.getExtLoad(ExtType, SL, DstEltVT, Chain, BasePTR,
                       LD->getPointerInfo().getWithOffset(Idx * Stride),
                       SrcEltVT, LD->getOriginalAlign(),
                       LD->getMemOperand()->getFlags(), LD->getAAInfo());

    BasePTR = DAG.getObjectPtrOffset(SL, BasePTR, TypeSize::Fixed(Stride));

    Vals.push_back(ScalarLoad.getValue(0));
    LoadChains.push_back(ScalarLoad.getValue(1));
  }

  // Merge the per-element load chains back into a single chain.
  SDValue NewChain = DAG.getNode(ISD::TokenFactor, SL, MVT::Other, LoadChains);
  SDValue Value = DAG.getBuildVector(DstVT, SL, Vals);

  return std::make_pair(Value, NewChain);
}
8960 
/// Turn a vector store into a sequence of scalar (trunc)stores, or — for
/// non-byte-sized elements — pack all elements into one integer and emit a
/// single store. Returns the resulting chain. Scalable vectors cannot be
/// scalarized and abort via report_fatal_error.
SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST,
                                             SelectionDAG &DAG) const {
  SDLoc SL(ST);

  SDValue Chain = ST->getChain();
  SDValue BasePtr = ST->getBasePtr();
  SDValue Value = ST->getValue();
  EVT StVT = ST->getMemoryVT();

  if (StVT.isScalableVector())
    report_fatal_error("Cannot scalarize scalable vector stores");

  // The type of the data we want to save
  EVT RegVT = Value.getValueType();
  EVT RegSclVT = RegVT.getScalarType();

  // The type of data as saved in memory.
  EVT MemSclVT = StVT.getScalarType();

  unsigned NumElem = StVT.getVectorNumElements();

  // A vector must always be stored in memory as-is, i.e. without any padding
  // between the elements, since various code depend on it, e.g. in the
  // handling of a bitcast of a vector type to int, which may be done with a
  // vector store followed by an integer load. A vector that does not have
  // elements that are byte-sized must therefore be stored as an integer
  // built out of the extracted vector elements.
  if (!MemSclVT.isByteSized()) {
    unsigned NumBits = StVT.getSizeInBits();
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), NumBits);

    SDValue CurrVal = DAG.getConstant(0, SL, IntVT);

    // OR each truncated, zero-extended element into its bit slot.
    for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, RegSclVT, Value,
                                DAG.getVectorIdxConstant(Idx, SL));
      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MemSclVT, Elt);
      SDValue ExtElt = DAG.getNode(ISD::ZERO_EXTEND, SL, IntVT, Trunc);
      // Big-endian layouts place element 0 in the most significant bits,
      // so the shift index is mirrored.
      unsigned ShiftIntoIdx =
          (DAG.getDataLayout().isBigEndian() ? (NumElem - 1) - Idx : Idx);
      SDValue ShiftAmount =
          DAG.getConstant(ShiftIntoIdx * MemSclVT.getSizeInBits(), SL, IntVT);
      SDValue ShiftedElt =
          DAG.getNode(ISD::SHL, SL, IntVT, ExtElt, ShiftAmount);
      CurrVal = DAG.getNode(ISD::OR, SL, IntVT, CurrVal, ShiftedElt);
    }

    return DAG.getStore(Chain, SL, CurrVal, BasePtr, ST->getPointerInfo(),
                        ST->getOriginalAlign(), ST->getMemOperand()->getFlags(),
                        ST->getAAInfo());
  }

  // Store Stride in bytes
  unsigned Stride = MemSclVT.getSizeInBits() / 8;
  assert(Stride && "Zero stride!");
  // Extract each of the elements from the original vector and save them into
  // memory individually.
  SmallVector<SDValue, 8> Stores;
  for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, RegSclVT, Value,
                              DAG.getVectorIdxConstant(Idx, SL));

    SDValue Ptr =
        DAG.getObjectPtrOffset(SL, BasePtr, TypeSize::Fixed(Idx * Stride));

    // This scalar TruncStore may be illegal, but we legalize it later.
    SDValue Store = DAG.getTruncStore(
        Chain, SL, Elt, Ptr, ST->getPointerInfo().getWithOffset(Idx * Stride),
        MemSclVT, ST->getOriginalAlign(), ST->getMemOperand()->getFlags(),
        ST->getAAInfo());

    Stores.push_back(Store);
  }

  // Merge the per-element store chains into a single output chain.
  return DAG.getNode(ISD::TokenFactor, SL, MVT::Other, Stores);
}
9037 
/// Expand an unaligned load into operations the target can perform.
///
/// Returns the pair (loaded value, output chain) that callers turn into a
/// MERGE_VALUES replacement for the original load. FP/vector loads are either
/// scalarized, re-expressed as a same-width integer load plus a bitcast, or
/// bounced through an aligned stack slot; integer loads are split into two
/// half-width loads combined with shift+or.
std::pair<SDValue, SDValue>
TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const {
  assert(LD->getAddressingMode() == ISD::UNINDEXED &&
         "unaligned indexed loads not implemented!");
  SDValue Chain = LD->getChain();
  SDValue Ptr = LD->getBasePtr();
  EVT VT = LD->getValueType(0);
  EVT LoadedVT = LD->getMemoryVT();
  SDLoc dl(LD);
  auto &MF = DAG.getMachineFunction();

  if (VT.isFloatingPoint() || VT.isVector()) {
    EVT intVT = EVT::getIntegerVT(*DAG.getContext(), LoadedVT.getSizeInBits());
    if (isTypeLegal(intVT) && isTypeLegal(LoadedVT)) {
      if (!isOperationLegalOrCustom(ISD::LOAD, intVT) &&
          LoadedVT.isVector()) {
        // Scalarize the load and let the individual components be handled.
        return scalarizeVectorLoad(LD, DAG);
      }

      // Expand to a (misaligned) integer load of the same size,
      // then bitconvert to floating point or vector.
      SDValue newLoad = DAG.getLoad(intVT, dl, Chain, Ptr,
                                    LD->getMemOperand());
      SDValue Result = DAG.getNode(ISD::BITCAST, dl, LoadedVT, newLoad);
      if (LoadedVT != VT)
        Result = DAG.getNode(VT.isFloatingPoint() ? ISD::FP_EXTEND :
                             ISD::ANY_EXTEND, dl, VT, Result);

      return std::make_pair(Result, newLoad.getValue(1));
    }

    // Copy the value to a (aligned) stack slot using (unaligned) integer
    // loads and stores, then do a (aligned) load from the stack slot.
    MVT RegVT = getRegisterType(*DAG.getContext(), intVT);
    unsigned LoadedBytes = LoadedVT.getStoreSize();
    unsigned RegBytes = RegVT.getSizeInBits() / 8;
    // Number of register-width chunks needed to cover the loaded bytes
    // (rounded up; the final chunk may be partial).
    unsigned NumRegs = (LoadedBytes + RegBytes - 1) / RegBytes;

    // Make sure the stack slot is also aligned for the register type.
    SDValue StackBase = DAG.CreateStackTemporary(LoadedVT, RegVT);
    auto FrameIndex = cast<FrameIndexSDNode>(StackBase.getNode())->getIndex();
    SmallVector<SDValue, 8> Stores;
    SDValue StackPtr = StackBase;
    unsigned Offset = 0;

    EVT PtrVT = Ptr.getValueType();
    EVT StackPtrVT = StackPtr.getValueType();

    SDValue PtrIncrement = DAG.getConstant(RegBytes, dl, PtrVT);
    SDValue StackPtrIncrement = DAG.getConstant(RegBytes, dl, StackPtrVT);

    // Do all but one copies using the full register width.
    for (unsigned i = 1; i < NumRegs; i++) {
      // Load one integer register's worth from the original location.
      SDValue Load = DAG.getLoad(
          RegVT, dl, Chain, Ptr, LD->getPointerInfo().getWithOffset(Offset),
          LD->getOriginalAlign(), LD->getMemOperand()->getFlags(),
          LD->getAAInfo());
      // Follow the load with a store to the stack slot.  Remember the store.
      Stores.push_back(DAG.getStore(
          Load.getValue(1), dl, Load, StackPtr,
          MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset)));
      // Increment the pointers.
      Offset += RegBytes;

      Ptr = DAG.getObjectPtrOffset(dl, Ptr, PtrIncrement);
      StackPtr = DAG.getObjectPtrOffset(dl, StackPtr, StackPtrIncrement);
    }

    // The last copy may be partial.  Do an extending load.
    EVT MemVT = EVT::getIntegerVT(*DAG.getContext(),
                                  8 * (LoadedBytes - Offset));
    SDValue Load =
        DAG.getExtLoad(ISD::EXTLOAD, dl, RegVT, Chain, Ptr,
                       LD->getPointerInfo().getWithOffset(Offset), MemVT,
                       LD->getOriginalAlign(), LD->getMemOperand()->getFlags(),
                       LD->getAAInfo());
    // Follow the load with a store to the stack slot.  Remember the store.
    // On big-endian machines this requires a truncating store to ensure
    // that the bits end up in the right place.
    Stores.push_back(DAG.getTruncStore(
        Load.getValue(1), dl, Load, StackPtr,
        MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset), MemVT));

    // The order of the stores doesn't matter - say it with a TokenFactor.
    SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Stores);

    // Finally, perform the original load only redirected to the stack slot.
    Load = DAG.getExtLoad(LD->getExtensionType(), dl, VT, TF, StackBase,
                          MachinePointerInfo::getFixedStack(MF, FrameIndex, 0),
                          LoadedVT);

    // Callers expect a MERGE_VALUES node.
    return std::make_pair(Load, TF);
  }

  assert(LoadedVT.isInteger() && !LoadedVT.isVector() &&
         "Unaligned load of unsupported type.");

  // Compute the new VT that is half the size of the old one.  This is an
  // integer MVT.
  unsigned NumBits = LoadedVT.getSizeInBits();
  EVT NewLoadedVT;
  NewLoadedVT = EVT::getIntegerVT(*DAG.getContext(), NumBits/2);
  NumBits >>= 1;

  Align Alignment = LD->getOriginalAlign();
  unsigned IncrementSize = NumBits / 8;
  ISD::LoadExtType HiExtType = LD->getExtensionType();

  // If the original load is NON_EXTLOAD, the hi part load must be ZEXTLOAD.
  if (HiExtType == ISD::NON_EXTLOAD)
    HiExtType = ISD::ZEXTLOAD;

  // Load the value in two parts.  The low half is always zero-extended; the
  // high half keeps the original extension kind.  Endianness decides which
  // half sits at the base address.
  SDValue Lo, Hi;
  if (DAG.getDataLayout().isLittleEndian()) {
    Lo = DAG.getExtLoad(ISD::ZEXTLOAD, dl, VT, Chain, Ptr, LD->getPointerInfo(),
                        NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
                        LD->getAAInfo());

    Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::Fixed(IncrementSize));
    Hi = DAG.getExtLoad(HiExtType, dl, VT, Chain, Ptr,
                        LD->getPointerInfo().getWithOffset(IncrementSize),
                        NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
                        LD->getAAInfo());
  } else {
    Hi = DAG.getExtLoad(HiExtType, dl, VT, Chain, Ptr, LD->getPointerInfo(),
                        NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
                        LD->getAAInfo());

    Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::Fixed(IncrementSize));
    Lo = DAG.getExtLoad(ISD::ZEXTLOAD, dl, VT, Chain, Ptr,
                        LD->getPointerInfo().getWithOffset(IncrementSize),
                        NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
                        LD->getAAInfo());
  }

  // aggregate the two parts: Result = (Hi << NumBits) | Lo.
  SDValue ShiftAmount =
      DAG.getConstant(NumBits, dl, getShiftAmountTy(Hi.getValueType(),
                                                    DAG.getDataLayout()));
  SDValue Result = DAG.getNode(ISD::SHL, dl, VT, Hi, ShiftAmount);
  Result = DAG.getNode(ISD::OR, dl, VT, Result, Lo);

  SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
                             Hi.getValue(1));

  return std::make_pair(Result, TF);
}
9189 
/// Expand an unaligned store into operations the target can perform.
///
/// FP/vector stores are either scalarized, re-expressed as a bitcast plus a
/// same-width integer store, or staged through an aligned stack slot and
/// copied out with integer loads/stores; integer stores are split into two
/// half-width truncating stores.
SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST,
                                             SelectionDAG &DAG) const {
  assert(ST->getAddressingMode() == ISD::UNINDEXED &&
         "unaligned indexed stores not implemented!");
  SDValue Chain = ST->getChain();
  SDValue Ptr = ST->getBasePtr();
  SDValue Val = ST->getValue();
  EVT VT = Val.getValueType();
  Align Alignment = ST->getOriginalAlign();
  auto &MF = DAG.getMachineFunction();
  EVT StoreMemVT = ST->getMemoryVT();

  SDLoc dl(ST);
  if (StoreMemVT.isFloatingPoint() || StoreMemVT.isVector()) {
    EVT intVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
    if (isTypeLegal(intVT)) {
      if (!isOperationLegalOrCustom(ISD::STORE, intVT) &&
          StoreMemVT.isVector()) {
        // Scalarize the store and let the individual components be handled.
        SDValue Result = scalarizeVectorStore(ST, DAG);
        return Result;
      }
      // Expand to a bitconvert of the value to the integer type of the
      // same size, then a (misaligned) int store.
      // FIXME: Does not handle truncating floating point stores!
      SDValue Result = DAG.getNode(ISD::BITCAST, dl, intVT, Val);
      Result = DAG.getStore(Chain, dl, Result, Ptr, ST->getPointerInfo(),
                            Alignment, ST->getMemOperand()->getFlags());
      return Result;
    }
    // Do a (aligned) store to a stack slot, then copy from the stack slot
    // to the final destination using (unaligned) integer loads and stores.
    MVT RegVT = getRegisterType(
        *DAG.getContext(),
        EVT::getIntegerVT(*DAG.getContext(), StoreMemVT.getSizeInBits()));
    EVT PtrVT = Ptr.getValueType();
    unsigned StoredBytes = StoreMemVT.getStoreSize();
    unsigned RegBytes = RegVT.getSizeInBits() / 8;
    // Number of register-width chunks needed to cover the stored bytes
    // (rounded up; the final chunk may be partial).
    unsigned NumRegs = (StoredBytes + RegBytes - 1) / RegBytes;

    // Make sure the stack slot is also aligned for the register type.
    SDValue StackPtr = DAG.CreateStackTemporary(StoreMemVT, RegVT);
    auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();

    // Perform the original store, only redirected to the stack slot.
    SDValue Store = DAG.getTruncStore(
        Chain, dl, Val, StackPtr,
        MachinePointerInfo::getFixedStack(MF, FrameIndex, 0), StoreMemVT);

    EVT StackPtrVT = StackPtr.getValueType();

    SDValue PtrIncrement = DAG.getConstant(RegBytes, dl, PtrVT);
    SDValue StackPtrIncrement = DAG.getConstant(RegBytes, dl, StackPtrVT);
    SmallVector<SDValue, 8> Stores;
    unsigned Offset = 0;

    // Do all but one copies using the full register width.
    for (unsigned i = 1; i < NumRegs; i++) {
      // Load one integer register's worth from the stack slot.
      SDValue Load = DAG.getLoad(
          RegVT, dl, Store, StackPtr,
          MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset));
      // Store it to the final location.  Remember the store.
      Stores.push_back(DAG.getStore(Load.getValue(1), dl, Load, Ptr,
                                    ST->getPointerInfo().getWithOffset(Offset),
                                    ST->getOriginalAlign(),
                                    ST->getMemOperand()->getFlags()));
      // Increment the pointers.
      Offset += RegBytes;
      StackPtr = DAG.getObjectPtrOffset(dl, StackPtr, StackPtrIncrement);
      Ptr = DAG.getObjectPtrOffset(dl, Ptr, PtrIncrement);
    }

    // The last store may be partial.  Do a truncating store.  On big-endian
    // machines this requires an extending load from the stack slot to ensure
    // that the bits are in the right place.
    EVT LoadMemVT =
        EVT::getIntegerVT(*DAG.getContext(), 8 * (StoredBytes - Offset));

    // Load from the stack slot.
    SDValue Load = DAG.getExtLoad(
        ISD::EXTLOAD, dl, RegVT, Store, StackPtr,
        MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset), LoadMemVT);

    Stores.push_back(
        DAG.getTruncStore(Load.getValue(1), dl, Load, Ptr,
                          ST->getPointerInfo().getWithOffset(Offset), LoadMemVT,
                          ST->getOriginalAlign(),
                          ST->getMemOperand()->getFlags(), ST->getAAInfo()));
    // The order of the stores doesn't matter - say it with a TokenFactor.
    SDValue Result = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Stores);
    return Result;
  }

  assert(StoreMemVT.isInteger() && !StoreMemVT.isVector() &&
         "Unaligned store of unknown type.");
  // Get the half-size VT
  EVT NewStoredVT = StoreMemVT.getHalfSizedIntegerVT(*DAG.getContext());
  unsigned NumBits = NewStoredVT.getFixedSizeInBits();
  unsigned IncrementSize = NumBits / 8;

  // Divide the stored value in two parts.
  SDValue ShiftAmount = DAG.getConstant(
      NumBits, dl, getShiftAmountTy(Val.getValueType(), DAG.getDataLayout()));
  SDValue Lo = Val;
  SDValue Hi = DAG.getNode(ISD::SRL, dl, VT, Val, ShiftAmount);

  // Store the two parts.  Endianness decides which half goes at the base
  // address: low half first on little-endian, high half first on big-endian.
  SDValue Store1, Store2;
  Store1 = DAG.getTruncStore(Chain, dl,
                             DAG.getDataLayout().isLittleEndian() ? Lo : Hi,
                             Ptr, ST->getPointerInfo(), NewStoredVT, Alignment,
                             ST->getMemOperand()->getFlags());

  Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::Fixed(IncrementSize));
  Store2 = DAG.getTruncStore(
      Chain, dl, DAG.getDataLayout().isLittleEndian() ? Hi : Lo, Ptr,
      ST->getPointerInfo().getWithOffset(IncrementSize), NewStoredVT, Alignment,
      ST->getMemOperand()->getFlags(), ST->getAAInfo());

  SDValue Result =
      DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store1, Store2);
  return Result;
}
9314 
9315 SDValue
IncrementMemoryAddress(SDValue Addr,SDValue Mask,const SDLoc & DL,EVT DataVT,SelectionDAG & DAG,bool IsCompressedMemory) const9316 TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask,
9317                                        const SDLoc &DL, EVT DataVT,
9318                                        SelectionDAG &DAG,
9319                                        bool IsCompressedMemory) const {
9320   SDValue Increment;
9321   EVT AddrVT = Addr.getValueType();
9322   EVT MaskVT = Mask.getValueType();
9323   assert(DataVT.getVectorElementCount() == MaskVT.getVectorElementCount() &&
9324          "Incompatible types of Data and Mask");
9325   if (IsCompressedMemory) {
9326     if (DataVT.isScalableVector())
9327       report_fatal_error(
9328           "Cannot currently handle compressed memory with scalable vectors");
9329     // Incrementing the pointer according to number of '1's in the mask.
9330     EVT MaskIntVT = EVT::getIntegerVT(*DAG.getContext(), MaskVT.getSizeInBits());
9331     SDValue MaskInIntReg = DAG.getBitcast(MaskIntVT, Mask);
9332     if (MaskIntVT.getSizeInBits() < 32) {
9333       MaskInIntReg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskInIntReg);
9334       MaskIntVT = MVT::i32;
9335     }
9336 
9337     // Count '1's with POPCNT.
9338     Increment = DAG.getNode(ISD::CTPOP, DL, MaskIntVT, MaskInIntReg);
9339     Increment = DAG.getZExtOrTrunc(Increment, DL, AddrVT);
9340     // Scale is an element size in bytes.
9341     SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL,
9342                                     AddrVT);
9343     Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale);
9344   } else if (DataVT.isScalableVector()) {
9345     Increment = DAG.getVScale(DL, AddrVT,
9346                               APInt(AddrVT.getFixedSizeInBits(),
9347                                     DataVT.getStoreSize().getKnownMinValue()));
9348   } else
9349     Increment = DAG.getConstant(DataVT.getStoreSize(), DL, AddrVT);
9350 
9351   return DAG.getNode(ISD::ADD, DL, AddrVT, Addr, Increment);
9352 }
9353 
/// Clamp a dynamic sub-vector index so that a sub-vector with \p SubEC
/// elements starting at the index stays inside a vector of type \p VecVT.
/// Returns the (possibly clamped) index value.
static SDValue clampDynamicVectorIndex(SelectionDAG &DAG, SDValue Idx,
                                       EVT VecVT, const SDLoc &dl,
                                       ElementCount SubEC) {
  assert(!(SubEC.isScalable() && VecVT.isFixedLengthVector()) &&
         "Cannot index a scalable vector within a fixed-width vector");

  unsigned NElts = VecVT.getVectorMinNumElements();
  unsigned NumSubElts = SubEC.getKnownMinValue();
  EVT IdxVT = Idx.getValueType();

  if (VecVT.isScalableVector() && !SubEC.isScalable()) {
    // If this is a constant index and we know the value plus the number of the
    // elements in the subvector minus one is less than the minimum number of
    // elements then it's safe to return Idx.
    if (auto *IdxCst = dyn_cast<ConstantSDNode>(Idx))
      if (IdxCst->getZExtValue() + (NumSubElts - 1) < NElts)
        return Idx;
    // Runtime clamp: Idx = umin(Idx, vscale * NElts - NumSubElts).
    SDValue VS =
        DAG.getVScale(dl, IdxVT, APInt(IdxVT.getFixedSizeInBits(), NElts));
    // Use a saturating subtract when NumSubElts may exceed the minimum
    // element count, so the bound cannot wrap below zero.
    unsigned SubOpcode = NumSubElts <= NElts ? ISD::SUB : ISD::USUBSAT;
    SDValue Sub = DAG.getNode(SubOpcode, dl, IdxVT, VS,
                              DAG.getConstant(NumSubElts, dl, IdxVT));
    return DAG.getNode(ISD::UMIN, dl, IdxVT, Idx, Sub);
  }
  // Single-element access into a power-of-two-sized vector: a simple bitmask
  // keeps the index in range.
  if (isPowerOf2_32(NElts) && NumSubElts == 1) {
    APInt Imm = APInt::getLowBitsSet(IdxVT.getSizeInBits(), Log2_32(NElts));
    return DAG.getNode(ISD::AND, dl, IdxVT, Idx,
                       DAG.getConstant(Imm, dl, IdxVT));
  }
  // General fixed-width case: clamp to the last index where the sub-vector
  // still fits (0 if the sub-vector covers the whole vector or more).
  unsigned MaxIndex = NumSubElts < NElts ? NElts - NumSubElts : 0;
  return DAG.getNode(ISD::UMIN, dl, IdxVT, Idx,
                     DAG.getConstant(MaxIndex, dl, IdxVT));
}
9387 
getVectorElementPointer(SelectionDAG & DAG,SDValue VecPtr,EVT VecVT,SDValue Index) const9388 SDValue TargetLowering::getVectorElementPointer(SelectionDAG &DAG,
9389                                                 SDValue VecPtr, EVT VecVT,
9390                                                 SDValue Index) const {
9391   return getVectorSubVecPointer(
9392       DAG, VecPtr, VecVT,
9393       EVT::getVectorVT(*DAG.getContext(), VecVT.getVectorElementType(), 1),
9394       Index);
9395 }
9396 
/// Compute the address of the sub-vector of type \p SubVecVT that starts at
/// element \p Index of the in-memory vector of type \p VecVT based at
/// \p VecPtr. The index is clamped so the sub-vector stays in bounds.
SDValue TargetLowering::getVectorSubVecPointer(SelectionDAG &DAG,
                                               SDValue VecPtr, EVT VecVT,
                                               EVT SubVecVT,
                                               SDValue Index) const {
  SDLoc dl(Index);
  // Make sure the index type is big enough to compute in.
  Index = DAG.getZExtOrTrunc(Index, dl, VecPtr.getValueType());

  EVT EltVT = VecVT.getVectorElementType();

  // Calculate the element offset and add it to the pointer.
  unsigned EltSize = EltVT.getFixedSizeInBits() / 8; // FIXME: should be ABI size.
  assert(EltSize * 8 == EltVT.getFixedSizeInBits() &&
         "Converting bits to bytes lost precision");
  assert(SubVecVT.getVectorElementType() == EltVT &&
         "Sub-vector must be a vector with matching element type");
  Index = clampDynamicVectorIndex(DAG, Index, VecVT, dl,
                                  SubVecVT.getVectorElementCount());

  EVT IdxVT = Index.getValueType();
  // For a scalable sub-vector the element index must additionally be scaled
  // by vscale before converting to a byte offset.
  if (SubVecVT.isScalableVector())
    Index =
        DAG.getNode(ISD::MUL, dl, IdxVT, Index,
                    DAG.getVScale(dl, IdxVT, APInt(IdxVT.getSizeInBits(), 1)));

  // Byte offset = element index * element size.
  Index = DAG.getNode(ISD::MUL, dl, IdxVT, Index,
                      DAG.getConstant(EltSize, dl, IdxVT));
  return DAG.getMemBasePlusOffset(VecPtr, Index, dl);
}
9426 
9427 //===----------------------------------------------------------------------===//
9428 // Implementation of Emulated TLS Model
9429 //===----------------------------------------------------------------------===//
9430 
/// Lower a TLS global address under the emulated-TLS model.
SDValue TargetLowering::LowerToTLSEmulatedModel(const GlobalAddressSDNode *GA,
                                                SelectionDAG &DAG) const {
  // Access to address of TLS variable xyz is lowered to a function call:
  //   __emutls_get_address( address of global variable named "__emutls_v.xyz" )
  EVT PtrVT = getPointerTy(DAG.getDataLayout());
  PointerType *VoidPtrType = Type::getInt8PtrTy(*DAG.getContext());
  SDLoc dl(GA);

  // Build the single argument: the address of the "__emutls_v.<name>" control
  // variable, which must already exist in the module.
  ArgListTy Args;
  ArgListEntry Entry;
  std::string NameString = ("__emutls_v." + GA->getGlobal()->getName()).str();
  Module *VariableModule = const_cast<Module*>(GA->getGlobal()->getParent());
  StringRef EmuTlsVarName(NameString);
  GlobalVariable *EmuTlsVar = VariableModule->getNamedGlobal(EmuTlsVarName);
  assert(EmuTlsVar && "Cannot find EmuTlsVar ");
  Entry.Node = DAG.getGlobalAddress(EmuTlsVar, dl, PtrVT);
  Entry.Ty = VoidPtrType;
  Args.push_back(Entry);

  SDValue EmuTlsGetAddr = DAG.getExternalSymbol("__emutls_get_address", PtrVT);

  // Emit the call to __emutls_get_address chained off the entry node.
  TargetLowering::CallLoweringInfo CLI(DAG);
  CLI.setDebugLoc(dl).setChain(DAG.getEntryNode());
  CLI.setLibCallee(CallingConv::C, VoidPtrType, EmuTlsGetAddr, std::move(Args));
  std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);

  // TLSADDR will be codegen'ed as call. Inform MFI that function has calls.
  // At last for X86 targets, maybe good for other targets too?
  MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
  MFI.setAdjustsStack(true); // Is this only for X86 target?
  MFI.setHasCalls(true);

  assert((GA->getOffset() == 0) &&
         "Emulated TLS must have zero offset in GlobalAddressSDNode");
  // Return the call's result value (the variable's address).
  return CallResult.first;
}
9467 
lowerCmpEqZeroToCtlzSrl(SDValue Op,SelectionDAG & DAG) const9468 SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op,
9469                                                 SelectionDAG &DAG) const {
9470   assert((Op->getOpcode() == ISD::SETCC) && "Input has to be a SETCC node.");
9471   if (!isCtlzFast())
9472     return SDValue();
9473   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
9474   SDLoc dl(Op);
9475   if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
9476     if (C->isZero() && CC == ISD::SETEQ) {
9477       EVT VT = Op.getOperand(0).getValueType();
9478       SDValue Zext = Op.getOperand(0);
9479       if (VT.bitsLT(MVT::i32)) {
9480         VT = MVT::i32;
9481         Zext = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Op.getOperand(0));
9482       }
9483       unsigned Log2b = Log2_32(VT.getSizeInBits());
9484       SDValue Clz = DAG.getNode(ISD::CTLZ, dl, VT, Zext);
9485       SDValue Scc = DAG.getNode(ISD::SRL, dl, VT, Clz,
9486                                 DAG.getConstant(Log2b, dl, MVT::i32));
9487       return DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Scc);
9488     }
9489   }
9490   return SDValue();
9491 }
9492 
expandIntMINMAX(SDNode * Node,SelectionDAG & DAG) const9493 SDValue TargetLowering::expandIntMINMAX(SDNode *Node, SelectionDAG &DAG) const {
9494   SDValue Op0 = Node->getOperand(0);
9495   SDValue Op1 = Node->getOperand(1);
9496   EVT VT = Op0.getValueType();
9497   unsigned Opcode = Node->getOpcode();
9498   SDLoc DL(Node);
9499 
9500   // umin(x,y) -> sub(x,usubsat(x,y))
9501   if (Opcode == ISD::UMIN && isOperationLegal(ISD::SUB, VT) &&
9502       isOperationLegal(ISD::USUBSAT, VT)) {
9503     return DAG.getNode(ISD::SUB, DL, VT, Op0,
9504                        DAG.getNode(ISD::USUBSAT, DL, VT, Op0, Op1));
9505   }
9506 
9507   // umax(x,y) -> add(x,usubsat(y,x))
9508   if (Opcode == ISD::UMAX && isOperationLegal(ISD::ADD, VT) &&
9509       isOperationLegal(ISD::USUBSAT, VT)) {
9510     return DAG.getNode(ISD::ADD, DL, VT, Op0,
9511                        DAG.getNode(ISD::USUBSAT, DL, VT, Op1, Op0));
9512   }
9513 
9514   // Expand Y = MAX(A, B) -> Y = (A > B) ? A : B
9515   ISD::CondCode CC;
9516   switch (Opcode) {
9517   default: llvm_unreachable("How did we get here?");
9518   case ISD::SMAX: CC = ISD::SETGT; break;
9519   case ISD::SMIN: CC = ISD::SETLT; break;
9520   case ISD::UMAX: CC = ISD::SETUGT; break;
9521   case ISD::UMIN: CC = ISD::SETULT; break;
9522   }
9523 
9524   // FIXME: Should really try to split the vector in case it's legal on a
9525   // subvector.
9526   if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
9527     return DAG.UnrollVectorOp(Node);
9528 
9529   EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
9530   SDValue Cond = DAG.getSetCC(DL, BoolVT, Op0, Op1, CC);
9531   return DAG.getSelect(DL, VT, Cond, Op0, Op1);
9532 }
9533 
/// Expand a saturating add/subtract ([SU]ADDSAT / [SU]SUBSAT) node.
///
/// Tries cheap min/max identities first, then falls back to the overflow
/// variant of the operation plus a select of the saturation value.
SDValue TargetLowering::expandAddSubSat(SDNode *Node, SelectionDAG &DAG) const {
  unsigned Opcode = Node->getOpcode();
  SDValue LHS = Node->getOperand(0);
  SDValue RHS = Node->getOperand(1);
  EVT VT = LHS.getValueType();
  SDLoc dl(Node);

  assert(VT == RHS.getValueType() && "Expected operands to be the same type");
  assert(VT.isInteger() && "Expected operands to be integers");

  // usub.sat(a, b) -> umax(a, b) - b
  if (Opcode == ISD::USUBSAT && isOperationLegal(ISD::UMAX, VT)) {
    SDValue Max = DAG.getNode(ISD::UMAX, dl, VT, LHS, RHS);
    return DAG.getNode(ISD::SUB, dl, VT, Max, RHS);
  }

  // uadd.sat(a, b) -> umin(a, ~b) + b
  if (Opcode == ISD::UADDSAT && isOperationLegal(ISD::UMIN, VT)) {
    SDValue InvRHS = DAG.getNOT(dl, RHS, VT);
    SDValue Min = DAG.getNode(ISD::UMIN, dl, VT, LHS, InvRHS);
    return DAG.getNode(ISD::ADD, dl, VT, Min, RHS);
  }

  // Map the saturating opcode to its overflow-reporting counterpart.
  unsigned OverflowOp;
  switch (Opcode) {
  case ISD::SADDSAT:
    OverflowOp = ISD::SADDO;
    break;
  case ISD::UADDSAT:
    OverflowOp = ISD::UADDO;
    break;
  case ISD::SSUBSAT:
    OverflowOp = ISD::SSUBO;
    break;
  case ISD::USUBSAT:
    OverflowOp = ISD::USUBO;
    break;
  default:
    llvm_unreachable("Expected method to receive signed or unsigned saturation "
                     "addition or subtraction node.");
  }

  // FIXME: Should really try to split the vector in case it's legal on a
  // subvector.
  if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
    return DAG.UnrollVectorOp(Node);

  // Perform the overflowing operation; value in result 0, overflow flag in
  // result 1.
  unsigned BitWidth = LHS.getScalarValueSizeInBits();
  EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue Result = DAG.getNode(OverflowOp, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
  SDValue SumDiff = Result.getValue(0);
  SDValue Overflow = Result.getValue(1);
  SDValue Zero = DAG.getConstant(0, dl, VT);
  SDValue AllOnes = DAG.getAllOnesConstant(dl, VT);

  if (Opcode == ISD::UADDSAT) {
    if (getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) {
      // Booleans are all-ones/zero here, so the overflow flag can be used
      // directly as a mask instead of a select:
      // (LHS + RHS) | OverflowMask
      SDValue OverflowMask = DAG.getSExtOrTrunc(Overflow, dl, VT);
      return DAG.getNode(ISD::OR, dl, VT, SumDiff, OverflowMask);
    }
    // Overflow ? 0xffff.... : (LHS + RHS)
    return DAG.getSelect(dl, VT, Overflow, AllOnes, SumDiff);
  }

  if (Opcode == ISD::USUBSAT) {
    if (getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) {
      // (LHS - RHS) & ~OverflowMask
      SDValue OverflowMask = DAG.getSExtOrTrunc(Overflow, dl, VT);
      SDValue Not = DAG.getNOT(dl, OverflowMask, VT);
      return DAG.getNode(ISD::AND, dl, VT, SumDiff, Not);
    }
    // Overflow ? 0 : (LHS - RHS)
    return DAG.getSelect(dl, VT, Overflow, Zero, SumDiff);
  }

  // Signed case: on overflow, saturate toward the sign of the true result.
  // (SumDiff >> BW-1) ^ MinVal gives SignedMax when SumDiff is negative
  // (positive overflow) and SignedMin otherwise.
  // Overflow ? (SumDiff >> BW) ^ MinVal : SumDiff
  APInt MinVal = APInt::getSignedMinValue(BitWidth);
  SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
  SDValue Shift = DAG.getNode(ISD::SRA, dl, VT, SumDiff,
                              DAG.getConstant(BitWidth - 1, dl, VT));
  Result = DAG.getNode(ISD::XOR, dl, VT, Shift, SatMin);
  return DAG.getSelect(dl, VT, Overflow, Result, SumDiff);
}
9618 
expandShlSat(SDNode * Node,SelectionDAG & DAG) const9619 SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
9620   unsigned Opcode = Node->getOpcode();
9621   bool IsSigned = Opcode == ISD::SSHLSAT;
9622   SDValue LHS = Node->getOperand(0);
9623   SDValue RHS = Node->getOperand(1);
9624   EVT VT = LHS.getValueType();
9625   SDLoc dl(Node);
9626 
9627   assert((Node->getOpcode() == ISD::SSHLSAT ||
9628           Node->getOpcode() == ISD::USHLSAT) &&
9629           "Expected a SHLSAT opcode");
9630   assert(VT == RHS.getValueType() && "Expected operands to be the same type");
9631   assert(VT.isInteger() && "Expected operands to be integers");
9632 
9633   if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
9634     return DAG.UnrollVectorOp(Node);
9635 
9636   // If LHS != (LHS << RHS) >> RHS, we have overflow and must saturate.
9637 
9638   unsigned BW = VT.getScalarSizeInBits();
9639   EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
9640   SDValue Result = DAG.getNode(ISD::SHL, dl, VT, LHS, RHS);
9641   SDValue Orig =
9642       DAG.getNode(IsSigned ? ISD::SRA : ISD::SRL, dl, VT, Result, RHS);
9643 
9644   SDValue SatVal;
9645   if (IsSigned) {
9646     SDValue SatMin = DAG.getConstant(APInt::getSignedMinValue(BW), dl, VT);
9647     SDValue SatMax = DAG.getConstant(APInt::getSignedMaxValue(BW), dl, VT);
9648     SDValue Cond =
9649         DAG.getSetCC(dl, BoolVT, LHS, DAG.getConstant(0, dl, VT), ISD::SETLT);
9650     SatVal = DAG.getSelect(dl, VT, Cond, SatMin, SatMax);
9651   } else {
9652     SatVal = DAG.getConstant(APInt::getMaxValue(BW), dl, VT);
9653   }
9654   SDValue Cond = DAG.getSetCC(dl, BoolVT, LHS, Orig, ISD::SETNE);
9655   return DAG.getSelect(dl, VT, Cond, SatVal, Result);
9656 }
9657 
/// Expand a [US]MULFIX or [US]MULFIXSAT node: multiply two operands that
/// share the same fixed-point scale and shift the (conceptually double-width)
/// product right by that scale, clamping the result for the SAT variants.
/// For vectors where no suitable multiply is available, returns a null
/// SDValue so the caller can unroll instead.
SDValue
TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
  assert((Node->getOpcode() == ISD::SMULFIX ||
          Node->getOpcode() == ISD::UMULFIX ||
          Node->getOpcode() == ISD::SMULFIXSAT ||
          Node->getOpcode() == ISD::UMULFIXSAT) &&
         "Expected a fixed point multiplication opcode");

  SDLoc dl(Node);
  SDValue LHS = Node->getOperand(0);
  SDValue RHS = Node->getOperand(1);
  EVT VT = LHS.getValueType();
  // Operand 2 is the (constant) number of fractional bits.
  unsigned Scale = Node->getConstantOperandVal(2);
  bool Saturating = (Node->getOpcode() == ISD::SMULFIXSAT ||
                     Node->getOpcode() == ISD::UMULFIXSAT);
  bool Signed = (Node->getOpcode() == ISD::SMULFIX ||
                 Node->getOpcode() == ISD::SMULFIXSAT);
  EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  unsigned VTSize = VT.getScalarSizeInBits();

  if (!Scale) {
    // [us]mul.fix(a, b, 0) -> mul(a, b)
    if (!Saturating) {
      if (isOperationLegalOrCustom(ISD::MUL, VT))
        return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
    } else if (Signed && isOperationLegalOrCustom(ISD::SMULO, VT)) {
      // smul.fix.sat(a, b, 0) expands to SMULO plus a select of the signed
      // clamp value when the overflow flag is set.
      SDValue Result =
          DAG.getNode(ISD::SMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
      SDValue Product = Result.getValue(0);
      SDValue Overflow = Result.getValue(1);
      SDValue Zero = DAG.getConstant(0, dl, VT);

      APInt MinVal = APInt::getSignedMinValue(VTSize);
      APInt MaxVal = APInt::getSignedMaxValue(VTSize);
      SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
      SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
      // Xor the inputs, if resulting sign bit is 0 the product will be
      // positive, else negative.
      SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
      SDValue ProdNeg = DAG.getSetCC(dl, BoolVT, Xor, Zero, ISD::SETLT);
      Result = DAG.getSelect(dl, VT, ProdNeg, SatMin, SatMax);
      return DAG.getSelect(dl, VT, Overflow, Result, Product);
    } else if (!Signed && isOperationLegalOrCustom(ISD::UMULO, VT)) {
      // umul.fix.sat(a, b, 0) expands to UMULO plus a select of the unsigned
      // max when the overflow flag is set.
      SDValue Result =
          DAG.getNode(ISD::UMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
      SDValue Product = Result.getValue(0);
      SDValue Overflow = Result.getValue(1);

      APInt MaxVal = APInt::getMaxValue(VTSize);
      SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
      return DAG.getSelect(dl, VT, Overflow, SatMax, Product);
    }
  }

  assert(((Signed && Scale < VTSize) || (!Signed && Scale <= VTSize)) &&
         "Expected scale to be less than the number of bits if signed or at "
         "most the number of bits if unsigned.");
  assert(LHS.getValueType() == RHS.getValueType() &&
         "Expected both operands to be the same type");

  // Get the upper and lower bits of the result.
  SDValue Lo, Hi;
  unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
  unsigned HiOp = Signed ? ISD::MULHS : ISD::MULHU;
  if (isOperationLegalOrCustom(LoHiOp, VT)) {
    SDValue Result = DAG.getNode(LoHiOp, dl, DAG.getVTList(VT, VT), LHS, RHS);
    Lo = Result.getValue(0);
    Hi = Result.getValue(1);
  } else if (isOperationLegalOrCustom(HiOp, VT)) {
    Lo = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
    Hi = DAG.getNode(HiOp, dl, VT, LHS, RHS);
  } else if (VT.isVector()) {
    // Signal the caller to unroll rather than failing hard.
    return SDValue();
  } else {
    report_fatal_error("Unable to expand fixed point multiplication.");
  }

  if (Scale == VTSize)
    // Result is just the top half since we'd be shifting by the width of the
    // operand. Overflow impossible so this works for both UMULFIX and
    // UMULFIXSAT.
    return Hi;

  // The result will need to be shifted right by the scale since both operands
  // are scaled. The result is given to us in 2 halves, so we only want part of
  // both in the result.
  EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
  SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
                               DAG.getConstant(Scale, dl, ShiftTy));
  if (!Saturating)
    return Result;

  if (!Signed) {
    // Unsigned overflow happened if the upper (VTSize - Scale) bits (of the
    // widened multiplication) aren't all zeroes.

    // Saturate to max if ((Hi >> Scale) != 0),
    // which is the same as if (Hi > ((1 << Scale) - 1))
    APInt MaxVal = APInt::getMaxValue(VTSize);
    SDValue LowMask = DAG.getConstant(APInt::getLowBitsSet(VTSize, Scale),
                                      dl, VT);
    Result = DAG.getSelectCC(dl, Hi, LowMask,
                             DAG.getConstant(MaxVal, dl, VT), Result,
                             ISD::SETUGT);

    return Result;
  }

  // Signed overflow happened if the upper (VTSize - Scale + 1) bits (of the
  // widened multiplication) aren't all ones or all zeroes.

  SDValue SatMin = DAG.getConstant(APInt::getSignedMinValue(VTSize), dl, VT);
  SDValue SatMax = DAG.getConstant(APInt::getSignedMaxValue(VTSize), dl, VT);

  if (Scale == 0) {
    // With Scale == 0 all overflow information lives in Hi: overflow iff Hi
    // is not the sign-extension of Lo.
    SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, Lo,
                               DAG.getConstant(VTSize - 1, dl, ShiftTy));
    SDValue Overflow = DAG.getSetCC(dl, BoolVT, Hi, Sign, ISD::SETNE);
    // Saturated to SatMin if wide product is negative, and SatMax if wide
    // product is positive ...
    SDValue Zero = DAG.getConstant(0, dl, VT);
    SDValue ResultIfOverflow = DAG.getSelectCC(dl, Hi, Zero, SatMin, SatMax,
                                               ISD::SETLT);
    // ... but only if we overflowed.
    return DAG.getSelect(dl, VT, Overflow, ResultIfOverflow, Result);
  }

  //  We handled Scale==0 above so all the bits to examine is in Hi.

  // Saturate to max if ((Hi >> (Scale - 1)) > 0),
  // which is the same as if (Hi > (1 << (Scale - 1)) - 1)
  SDValue LowMask = DAG.getConstant(APInt::getLowBitsSet(VTSize, Scale - 1),
                                    dl, VT);
  Result = DAG.getSelectCC(dl, Hi, LowMask, SatMax, Result, ISD::SETGT);
  // Saturate to min if (Hi >> (Scale - 1)) < -1),
  // which is the same as if (HI < (-1 << (Scale - 1))
  SDValue HighMask =
      DAG.getConstant(APInt::getHighBitsSet(VTSize, VTSize - Scale + 1),
                      dl, VT);
  Result = DAG.getSelectCC(dl, Hi, HighMask, SatMin, Result, ISD::SETLT);
  return Result;
}
9800 
/// Expand [US]DIVFIX[SAT] by pre-shifting the operands so that a plain
/// integer division in the same type yields the correctly scaled quotient:
/// the LHS is shifted up by as much headroom as it has, and the RHS shifted
/// down by the remaining part of Scale. Returns a null SDValue when there is
/// not enough room to do this in the current type.
SDValue
TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl,
                                    SDValue LHS, SDValue RHS,
                                    unsigned Scale, SelectionDAG &DAG) const {
  assert((Opcode == ISD::SDIVFIX || Opcode == ISD::SDIVFIXSAT ||
          Opcode == ISD::UDIVFIX || Opcode == ISD::UDIVFIXSAT) &&
         "Expected a fixed point division opcode");

  EVT VT = LHS.getValueType();
  bool Signed = Opcode == ISD::SDIVFIX || Opcode == ISD::SDIVFIXSAT;
  bool Saturating = Opcode == ISD::SDIVFIXSAT || Opcode == ISD::UDIVFIXSAT;
  EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

  // If there is enough room in the type to upscale the LHS or downscale the
  // RHS before the division, we can perform it in this type without having to
  // resize. For signed operations, the LHS headroom is the number of
  // redundant sign bits, and for unsigned ones it is the number of zeroes.
  // The headroom for the RHS is the number of trailing zeroes.
  unsigned LHSLead = Signed ? DAG.ComputeNumSignBits(LHS) - 1
                            : DAG.computeKnownBits(LHS).countMinLeadingZeros();
  unsigned RHSTrail = DAG.computeKnownBits(RHS).countMinTrailingZeros();

  // For signed saturating operations, we need to be able to detect true integer
  // division overflow; that is, when you have MIN / -EPS. However, this
  // is undefined behavior and if we emit divisions that could take such
  // values it may cause undesired behavior (arithmetic exceptions on x86, for
  // example).
  // Avoid this by requiring an extra bit so that we never get this case.
  // FIXME: This is a bit unfortunate as it means that for an 8-bit 7-scale
  // signed saturating division, we need to emit a whopping 32-bit division.
  if (LHSLead + RHSTrail < Scale + (unsigned)(Saturating && Signed))
    return SDValue();

  // Prefer to consume the scale from the LHS headroom first; whatever is left
  // over comes out of the RHS's trailing zeroes.
  unsigned LHSShift = std::min(LHSLead, Scale);
  unsigned RHSShift = Scale - LHSShift;

  // At this point, we know that if we shift the LHS up by LHSShift and the
  // RHS down by RHSShift, we can emit a regular division with a final scaling
  // factor of Scale.

  EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
  if (LHSShift)
    LHS = DAG.getNode(ISD::SHL, dl, VT, LHS,
                      DAG.getConstant(LHSShift, dl, ShiftTy));
  if (RHSShift)
    RHS = DAG.getNode(Signed ? ISD::SRA : ISD::SRL, dl, VT, RHS,
                      DAG.getConstant(RHSShift, dl, ShiftTy));

  SDValue Quot;
  if (Signed) {
    // For signed operations, if the resulting quotient is negative and the
    // remainder is nonzero, subtract 1 from the quotient to round towards
    // negative infinity.
    SDValue Rem;
    // FIXME: Ideally we would always produce an SDIVREM here, but if the
    // type isn't legal, SDIVREM cannot be expanded. There is no reason why
    // we couldn't just form a libcall, but the type legalizer doesn't do it.
    if (isTypeLegal(VT) &&
        isOperationLegalOrCustom(ISD::SDIVREM, VT)) {
      Quot = DAG.getNode(ISD::SDIVREM, dl,
                         DAG.getVTList(VT, VT),
                         LHS, RHS);
      Rem = Quot.getValue(1);
      Quot = Quot.getValue(0);
    } else {
      Quot = DAG.getNode(ISD::SDIV, dl, VT,
                         LHS, RHS);
      Rem = DAG.getNode(ISD::SREM, dl, VT,
                        LHS, RHS);
    }
    SDValue Zero = DAG.getConstant(0, dl, VT);
    SDValue RemNonZero = DAG.getSetCC(dl, BoolVT, Rem, Zero, ISD::SETNE);
    SDValue LHSNeg = DAG.getSetCC(dl, BoolVT, LHS, Zero, ISD::SETLT);
    SDValue RHSNeg = DAG.getSetCC(dl, BoolVT, RHS, Zero, ISD::SETLT);
    // The quotient is negative exactly when the operand signs differ.
    SDValue QuotNeg = DAG.getNode(ISD::XOR, dl, BoolVT, LHSNeg, RHSNeg);
    SDValue Sub1 = DAG.getNode(ISD::SUB, dl, VT, Quot,
                               DAG.getConstant(1, dl, VT));
    Quot = DAG.getSelect(dl, VT,
                         DAG.getNode(ISD::AND, dl, BoolVT, RemNonZero, QuotNeg),
                         Sub1, Quot);
  } else
    Quot = DAG.getNode(ISD::UDIV, dl, VT,
                       LHS, RHS);

  return Quot;
}
9887 
expandUADDSUBO(SDNode * Node,SDValue & Result,SDValue & Overflow,SelectionDAG & DAG) const9888 void TargetLowering::expandUADDSUBO(
9889     SDNode *Node, SDValue &Result, SDValue &Overflow, SelectionDAG &DAG) const {
9890   SDLoc dl(Node);
9891   SDValue LHS = Node->getOperand(0);
9892   SDValue RHS = Node->getOperand(1);
9893   bool IsAdd = Node->getOpcode() == ISD::UADDO;
9894 
9895   // If ADD/SUBCARRY is legal, use that instead.
9896   unsigned OpcCarry = IsAdd ? ISD::ADDCARRY : ISD::SUBCARRY;
9897   if (isOperationLegalOrCustom(OpcCarry, Node->getValueType(0))) {
9898     SDValue CarryIn = DAG.getConstant(0, dl, Node->getValueType(1));
9899     SDValue NodeCarry = DAG.getNode(OpcCarry, dl, Node->getVTList(),
9900                                     { LHS, RHS, CarryIn });
9901     Result = SDValue(NodeCarry.getNode(), 0);
9902     Overflow = SDValue(NodeCarry.getNode(), 1);
9903     return;
9904   }
9905 
9906   Result = DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, dl,
9907                             LHS.getValueType(), LHS, RHS);
9908 
9909   EVT ResultType = Node->getValueType(1);
9910   EVT SetCCType = getSetCCResultType(
9911       DAG.getDataLayout(), *DAG.getContext(), Node->getValueType(0));
9912   SDValue SetCC;
9913   if (IsAdd && isOneConstant(RHS)) {
9914     // Special case: uaddo X, 1 overflowed if X+1 is 0. This potential reduces
9915     // the live range of X. We assume comparing with 0 is cheap.
9916     // The general case (X + C) < C is not necessarily beneficial. Although we
9917     // reduce the live range of X, we may introduce the materialization of
9918     // constant C.
9919     SetCC =
9920         DAG.getSetCC(dl, SetCCType, Result,
9921                      DAG.getConstant(0, dl, Node->getValueType(0)), ISD::SETEQ);
9922   } else {
9923     ISD::CondCode CC = IsAdd ? ISD::SETULT : ISD::SETUGT;
9924     SetCC = DAG.getSetCC(dl, SetCCType, Result, LHS, CC);
9925   }
9926   Overflow = DAG.getBoolExtOrTrunc(SetCC, dl, ResultType, ResultType);
9927 }
9928 
expandSADDSUBO(SDNode * Node,SDValue & Result,SDValue & Overflow,SelectionDAG & DAG) const9929 void TargetLowering::expandSADDSUBO(
9930     SDNode *Node, SDValue &Result, SDValue &Overflow, SelectionDAG &DAG) const {
9931   SDLoc dl(Node);
9932   SDValue LHS = Node->getOperand(0);
9933   SDValue RHS = Node->getOperand(1);
9934   bool IsAdd = Node->getOpcode() == ISD::SADDO;
9935 
9936   Result = DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, dl,
9937                             LHS.getValueType(), LHS, RHS);
9938 
9939   EVT ResultType = Node->getValueType(1);
9940   EVT OType = getSetCCResultType(
9941       DAG.getDataLayout(), *DAG.getContext(), Node->getValueType(0));
9942 
9943   // If SADDSAT/SSUBSAT is legal, compare results to detect overflow.
9944   unsigned OpcSat = IsAdd ? ISD::SADDSAT : ISD::SSUBSAT;
9945   if (isOperationLegal(OpcSat, LHS.getValueType())) {
9946     SDValue Sat = DAG.getNode(OpcSat, dl, LHS.getValueType(), LHS, RHS);
9947     SDValue SetCC = DAG.getSetCC(dl, OType, Result, Sat, ISD::SETNE);
9948     Overflow = DAG.getBoolExtOrTrunc(SetCC, dl, ResultType, ResultType);
9949     return;
9950   }
9951 
9952   SDValue Zero = DAG.getConstant(0, dl, LHS.getValueType());
9953 
9954   // For an addition, the result should be less than one of the operands (LHS)
9955   // if and only if the other operand (RHS) is negative, otherwise there will
9956   // be overflow.
9957   // For a subtraction, the result should be less than one of the operands
9958   // (LHS) if and only if the other operand (RHS) is (non-zero) positive,
9959   // otherwise there will be overflow.
9960   SDValue ResultLowerThanLHS = DAG.getSetCC(dl, OType, Result, LHS, ISD::SETLT);
9961   SDValue ConditionRHS =
9962       DAG.getSetCC(dl, OType, RHS, Zero, IsAdd ? ISD::SETLT : ISD::SETGT);
9963 
9964   Overflow = DAG.getBoolExtOrTrunc(
9965       DAG.getNode(ISD::XOR, dl, OType, ConditionRHS, ResultLowerThanLHS), dl,
9966       ResultType, ResultType);
9967 }
9968 
/// Expand [SU]MULO into a multiply plus an explicit overflow check, writing
/// the product to \p Result and the overflow flag to \p Overflow. Returns
/// false only when no strategy applies (vectors with no suitable multiply),
/// letting the caller unroll instead.
bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
                                SDValue &Overflow, SelectionDAG &DAG) const {
  SDLoc dl(Node);
  EVT VT = Node->getValueType(0);
  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue LHS = Node->getOperand(0);
  SDValue RHS = Node->getOperand(1);
  bool isSigned = Node->getOpcode() == ISD::SMULO;

  // For power-of-two multiplications we can use a simpler shift expansion.
  if (ConstantSDNode *RHSC = isConstOrConstSplat(RHS)) {
    const APInt &C = RHSC->getAPIntValue();
    // mulo(X, 1 << S) -> { X << S, (X << S) >> S != X }
    if (C.isPowerOf2()) {
      // smulo(x, signed_min) is same as umulo(x, signed_min).
      bool UseArithShift = isSigned && !C.isMinSignedValue();
      EVT ShiftAmtTy = getShiftAmountTy(VT, DAG.getDataLayout());
      SDValue ShiftAmt = DAG.getConstant(C.logBase2(), dl, ShiftAmtTy);
      Result = DAG.getNode(ISD::SHL, dl, VT, LHS, ShiftAmt);
      // Overflow iff shifting back fails to reproduce the original value.
      Overflow = DAG.getSetCC(dl, SetCCVT,
          DAG.getNode(UseArithShift ? ISD::SRA : ISD::SRL,
                      dl, VT, Result, ShiftAmt),
          LHS, ISD::SETNE);
      return true;
    }
  }

  // General case: compute the double-width product and inspect its top half.
  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits() * 2);
  if (VT.isVector())
    WideVT =
        EVT::getVectorVT(*DAG.getContext(), WideVT, VT.getVectorElementCount());

  SDValue BottomHalf;
  SDValue TopHalf;
  // Per-signedness opcode table, indexed by isSigned:
  // { high-half multiply, combined lo/hi multiply, widening extension }.
  static const unsigned Ops[2][3] =
      { { ISD::MULHU, ISD::UMUL_LOHI, ISD::ZERO_EXTEND },
        { ISD::MULHS, ISD::SMUL_LOHI, ISD::SIGN_EXTEND }};
  if (isOperationLegalOrCustom(Ops[isSigned][0], VT)) {
    BottomHalf = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
    TopHalf = DAG.getNode(Ops[isSigned][0], dl, VT, LHS, RHS);
  } else if (isOperationLegalOrCustom(Ops[isSigned][1], VT)) {
    BottomHalf = DAG.getNode(Ops[isSigned][1], dl, DAG.getVTList(VT, VT), LHS,
                             RHS);
    TopHalf = BottomHalf.getValue(1);
  } else if (isTypeLegal(WideVT)) {
    // Multiply in the wide type and split the product afterwards.
    LHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, LHS);
    RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS);
    SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS);
    BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
    SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl,
        getShiftAmountTy(WideVT, DAG.getDataLayout()));
    TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT,
                          DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt));
  } else {
    if (VT.isVector())
      return false;

    // We can fall back to a libcall with an illegal type for the MUL if we
    // have a libcall big enough.
    // Also, we can fall back to a division in some cases, but that's a big
    // performance hit in the general case.
    RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
    if (WideVT == MVT::i16)
      LC = RTLIB::MUL_I16;
    else if (WideVT == MVT::i32)
      LC = RTLIB::MUL_I32;
    else if (WideVT == MVT::i64)
      LC = RTLIB::MUL_I64;
    else if (WideVT == MVT::i128)
      LC = RTLIB::MUL_I128;
    assert(LC != RTLIB::UNKNOWN_LIBCALL && "Cannot expand this operation!");

    SDValue HiLHS;
    SDValue HiRHS;
    if (isSigned) {
      // The high part is obtained by SRA'ing all but one of the bits of low
      // part.
      unsigned LoSize = VT.getFixedSizeInBits();
      HiLHS =
          DAG.getNode(ISD::SRA, dl, VT, LHS,
                      DAG.getConstant(LoSize - 1, dl,
                                      getPointerTy(DAG.getDataLayout())));
      HiRHS =
          DAG.getNode(ISD::SRA, dl, VT, RHS,
                      DAG.getConstant(LoSize - 1, dl,
                                      getPointerTy(DAG.getDataLayout())));
    } else {
        HiLHS = DAG.getConstant(0, dl, VT);
        HiRHS = DAG.getConstant(0, dl, VT);
    }

    // Here we're passing the 2 arguments explicitly as 4 arguments that are
    // pre-lowered to the correct types. This all depends upon WideVT not
    // being a legal type for the architecture and thus has to be split to
    // two arguments.
    SDValue Ret;
    TargetLowering::MakeLibCallOptions CallOptions;
    CallOptions.setSExt(isSigned);
    CallOptions.setIsPostTypeLegalization(true);
    if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) {
      // Halves of WideVT are packed into registers in different order
      // depending on platform endianness. This is usually handled by
      // the C calling convention, but we can't defer to it in
      // the legalizer.
      SDValue Args[] = { LHS, HiLHS, RHS, HiRHS };
      Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
    } else {
      SDValue Args[] = { HiLHS, LHS, HiRHS, RHS };
      Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
    }
    assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
           "Ret value is a collection of constituent nodes holding result.");
    if (DAG.getDataLayout().isLittleEndian()) {
      // Same as above.
      BottomHalf = Ret.getOperand(0);
      TopHalf = Ret.getOperand(1);
    } else {
      BottomHalf = Ret.getOperand(1);
      TopHalf = Ret.getOperand(0);
    }
  }

  Result = BottomHalf;
  if (isSigned) {
    // Signed overflow iff the top half is not the sign-extension of the
    // bottom half.
    SDValue ShiftAmt = DAG.getConstant(
        VT.getScalarSizeInBits() - 1, dl,
        getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout()));
    SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt);
    Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
  } else {
    // Unsigned overflow iff any bit of the top half is set.
    Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf,
                            DAG.getConstant(0, dl, VT), ISD::SETNE);
  }

  // Truncate the result if SetCC returns a larger type than needed.
  EVT RType = Node->getValueType(1);
  if (RType.bitsLT(Overflow.getValueType()))
    Overflow = DAG.getNode(ISD::TRUNCATE, dl, RType, Overflow);

  assert(RType.getSizeInBits() == Overflow.getValueSizeInBits() &&
         "Unexpected result type for S/UMULO legalization");
  return true;
}
10112 
expandVecReduce(SDNode * Node,SelectionDAG & DAG) const10113 SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
10114   SDLoc dl(Node);
10115   unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Node->getOpcode());
10116   SDValue Op = Node->getOperand(0);
10117   EVT VT = Op.getValueType();
10118 
10119   if (VT.isScalableVector())
10120     report_fatal_error(
10121         "Expanding reductions for scalable vectors is undefined.");
10122 
10123   // Try to use a shuffle reduction for power of two vectors.
10124   if (VT.isPow2VectorType()) {
10125     while (VT.getVectorNumElements() > 1) {
10126       EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
10127       if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
10128         break;
10129 
10130       SDValue Lo, Hi;
10131       std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
10132       Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi);
10133       VT = HalfVT;
10134     }
10135   }
10136 
10137   EVT EltVT = VT.getVectorElementType();
10138   unsigned NumElts = VT.getVectorNumElements();
10139 
10140   SmallVector<SDValue, 8> Ops;
10141   DAG.ExtractVectorElements(Op, Ops, 0, NumElts);
10142 
10143   SDValue Res = Ops[0];
10144   for (unsigned i = 1; i < NumElts; i++)
10145     Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Node->getFlags());
10146 
10147   // Result type may be wider than element type.
10148   if (EltVT != Node->getValueType(0))
10149     Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res);
10150   return Res;
10151 }
10152 
expandVecReduceSeq(SDNode * Node,SelectionDAG & DAG) const10153 SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
10154   SDLoc dl(Node);
10155   SDValue AccOp = Node->getOperand(0);
10156   SDValue VecOp = Node->getOperand(1);
10157   SDNodeFlags Flags = Node->getFlags();
10158 
10159   EVT VT = VecOp.getValueType();
10160   EVT EltVT = VT.getVectorElementType();
10161 
10162   if (VT.isScalableVector())
10163     report_fatal_error(
10164         "Expanding reductions for scalable vectors is undefined.");
10165 
10166   unsigned NumElts = VT.getVectorNumElements();
10167 
10168   SmallVector<SDValue, 8> Ops;
10169   DAG.ExtractVectorElements(VecOp, Ops, 0, NumElts);
10170 
10171   unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Node->getOpcode());
10172 
10173   SDValue Res = AccOp;
10174   for (unsigned i = 0; i < NumElts; i++)
10175     Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
10176 
10177   return Res;
10178 }
10179 
expandREM(SDNode * Node,SDValue & Result,SelectionDAG & DAG) const10180 bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
10181                                SelectionDAG &DAG) const {
10182   EVT VT = Node->getValueType(0);
10183   SDLoc dl(Node);
10184   bool isSigned = Node->getOpcode() == ISD::SREM;
10185   unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
10186   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
10187   SDValue Dividend = Node->getOperand(0);
10188   SDValue Divisor = Node->getOperand(1);
10189   if (isOperationLegalOrCustom(DivRemOpc, VT)) {
10190     SDVTList VTs = DAG.getVTList(VT, VT);
10191     Result = DAG.getNode(DivRemOpc, dl, VTs, Dividend, Divisor).getValue(1);
10192     return true;
10193   }
10194   if (isOperationLegalOrCustom(DivOpc, VT)) {
10195     // X % Y -> X-X/Y*Y
10196     SDValue Divide = DAG.getNode(DivOpc, dl, VT, Dividend, Divisor);
10197     SDValue Mul = DAG.getNode(ISD::MUL, dl, VT, Divide, Divisor);
10198     Result = DAG.getNode(ISD::SUB, dl, VT, Dividend, Mul);
10199     return true;
10200   }
10201   return false;
10202 }
10203 
/// Expand FP_TO_[SU]INT_SAT: convert Src to an integer, clamping the value
/// to the range of a SatWidth-bit integer and mapping NaN to zero, as the
/// selects below demonstrate. Uses FMINNUM/FMAXNUM when the bounds are
/// exactly representable, otherwise a chain of compare+select.
SDValue TargetLowering::expandFP_TO_INT_SAT(SDNode *Node,
                                            SelectionDAG &DAG) const {
  bool IsSigned = Node->getOpcode() == ISD::FP_TO_SINT_SAT;
  SDLoc dl(SDValue(Node, 0));
  SDValue Src = Node->getOperand(0);

  // DstVT is the result type, while SatVT is the size to which we saturate
  EVT SrcVT = Src.getValueType();
  EVT DstVT = Node->getValueType(0);

  EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
  unsigned SatWidth = SatVT.getScalarSizeInBits();
  unsigned DstWidth = DstVT.getScalarSizeInBits();
  assert(SatWidth <= DstWidth &&
         "Expected saturation width smaller than result width");

  // Determine minimum and maximum integer values and their corresponding
  // floating-point values.
  APInt MinInt, MaxInt;
  if (IsSigned) {
    MinInt = APInt::getSignedMinValue(SatWidth).sext(DstWidth);
    MaxInt = APInt::getSignedMaxValue(SatWidth).sext(DstWidth);
  } else {
    MinInt = APInt::getMinValue(SatWidth).zext(DstWidth);
    MaxInt = APInt::getMaxValue(SatWidth).zext(DstWidth);
  }

  // We cannot risk emitting FP_TO_XINT nodes with a source VT of f16, as
  // libcall emission cannot handle this. Large result types will fail.
  if (SrcVT == MVT::f16) {
    Src = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, Src);
    SrcVT = Src.getValueType();
  }

  // Round the integer bounds towards zero so the FP bounds never lie outside
  // the representable integer range.
  APFloat MinFloat(DAG.EVTToAPFloatSemantics(SrcVT));
  APFloat MaxFloat(DAG.EVTToAPFloatSemantics(SrcVT));

  APFloat::opStatus MinStatus =
      MinFloat.convertFromAPInt(MinInt, IsSigned, APFloat::rmTowardZero);
  APFloat::opStatus MaxStatus =
      MaxFloat.convertFromAPInt(MaxInt, IsSigned, APFloat::rmTowardZero);
  bool AreExactFloatBounds = !(MinStatus & APFloat::opStatus::opInexact) &&
                             !(MaxStatus & APFloat::opStatus::opInexact);

  SDValue MinFloatNode = DAG.getConstantFP(MinFloat, dl, SrcVT);
  SDValue MaxFloatNode = DAG.getConstantFP(MaxFloat, dl, SrcVT);

  // If the integer bounds are exactly representable as floats and min/max are
  // legal, emit a min+max+fptoi sequence. Otherwise we have to use a sequence
  // of comparisons and selects.
  bool MinMaxLegal = isOperationLegal(ISD::FMINNUM, SrcVT) &&
                     isOperationLegal(ISD::FMAXNUM, SrcVT);
  if (AreExactFloatBounds && MinMaxLegal) {
    SDValue Clamped = Src;

    // Clamp Src by MinFloat from below. If Src is NaN the result is MinFloat.
    Clamped = DAG.getNode(ISD::FMAXNUM, dl, SrcVT, Clamped, MinFloatNode);
    // Clamp by MaxFloat from above. NaN cannot occur.
    Clamped = DAG.getNode(ISD::FMINNUM, dl, SrcVT, Clamped, MaxFloatNode);
    // Convert clamped value to integer.
    SDValue FpToInt = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT,
                                  dl, DstVT, Clamped);

    // In the unsigned case we're done, because we mapped NaN to MinFloat,
    // which will cast to zero.
    if (!IsSigned)
      return FpToInt;

    // Otherwise, select 0 if Src is NaN.
    SDValue ZeroInt = DAG.getConstant(0, dl, DstVT);
    return DAG.getSelectCC(dl, Src, Src, ZeroInt, FpToInt,
                           ISD::CondCode::SETUO);
  }

  SDValue MinIntNode = DAG.getConstant(MinInt, dl, DstVT);
  SDValue MaxIntNode = DAG.getConstant(MaxInt, dl, DstVT);

  // Result of direct conversion. The assumption here is that the operation is
  // non-trapping and it's fine to apply it to an out-of-range value if we
  // select it away later.
  SDValue FpToInt =
      DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, DstVT, Src);

  SDValue Select = FpToInt;

  // If Src ULT MinFloat, select MinInt. In particular, this also selects
  // MinInt if Src is NaN.
  Select = DAG.getSelectCC(dl, Src, MinFloatNode, MinIntNode, Select,
                           ISD::CondCode::SETULT);
  // If Src OGT MaxFloat, select MaxInt.
  Select = DAG.getSelectCC(dl, Src, MaxFloatNode, MaxIntNode, Select,
                           ISD::CondCode::SETOGT);

  // In the unsigned case we are done, because we mapped NaN to MinInt, which
  // is already zero.
  if (!IsSigned)
    return Select;

  // Otherwise, select 0 if Src is NaN.
  SDValue ZeroInt = DAG.getConstant(0, dl, DstVT);
  return DAG.getSelectCC(dl, Src, Src, ZeroInt, Select, ISD::CondCode::SETUO);
}
10306 
expandVectorSplice(SDNode * Node,SelectionDAG & DAG) const10307 SDValue TargetLowering::expandVectorSplice(SDNode *Node,
10308                                            SelectionDAG &DAG) const {
10309   assert(Node->getOpcode() == ISD::VECTOR_SPLICE && "Unexpected opcode!");
10310   assert(Node->getValueType(0).isScalableVector() &&
10311          "Fixed length vector types expected to use SHUFFLE_VECTOR!");
10312 
10313   EVT VT = Node->getValueType(0);
10314   SDValue V1 = Node->getOperand(0);
10315   SDValue V2 = Node->getOperand(1);
10316   int64_t Imm = cast<ConstantSDNode>(Node->getOperand(2))->getSExtValue();
10317   SDLoc DL(Node);
10318 
10319   // Expand through memory thusly:
10320   //  Alloca CONCAT_VECTORS_TYPES(V1, V2) Ptr
10321   //  Store V1, Ptr
10322   //  Store V2, Ptr + sizeof(V1)
10323   //  If (Imm < 0)
10324   //    TrailingElts = -Imm
10325   //    Ptr = Ptr + sizeof(V1) - (TrailingElts * sizeof(VT.Elt))
10326   //  else
10327   //    Ptr = Ptr + (Imm * sizeof(VT.Elt))
10328   //  Res = Load Ptr
10329 
10330   Align Alignment = DAG.getReducedAlign(VT, /*UseABI=*/false);
10331 
10332   EVT MemVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
10333                                VT.getVectorElementCount() * 2);
10334   SDValue StackPtr = DAG.CreateStackTemporary(MemVT.getStoreSize(), Alignment);
10335   EVT PtrVT = StackPtr.getValueType();
10336   auto &MF = DAG.getMachineFunction();
10337   auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
10338   auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
10339 
10340   // Store the lo part of CONCAT_VECTORS(V1, V2)
10341   SDValue StoreV1 = DAG.getStore(DAG.getEntryNode(), DL, V1, StackPtr, PtrInfo);
10342   // Store the hi part of CONCAT_VECTORS(V1, V2)
10343   SDValue OffsetToV2 = DAG.getVScale(
10344       DL, PtrVT,
10345       APInt(PtrVT.getFixedSizeInBits(), VT.getStoreSize().getKnownMinValue()));
10346   SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, OffsetToV2);
10347   SDValue StoreV2 = DAG.getStore(StoreV1, DL, V2, StackPtr2, PtrInfo);
10348 
10349   if (Imm >= 0) {
10350     // Load back the required element. getVectorElementPointer takes care of
10351     // clamping the index if it's out-of-bounds.
10352     StackPtr = getVectorElementPointer(DAG, StackPtr, VT, Node->getOperand(2));
10353     // Load the spliced result
10354     return DAG.getLoad(VT, DL, StoreV2, StackPtr,
10355                        MachinePointerInfo::getUnknownStack(MF));
10356   }
10357 
10358   uint64_t TrailingElts = -Imm;
10359 
10360   // NOTE: TrailingElts must be clamped so as not to read outside of V1:V2.
10361   TypeSize EltByteSize = VT.getVectorElementType().getStoreSize();
10362   SDValue TrailingBytes =
10363       DAG.getConstant(TrailingElts * EltByteSize, DL, PtrVT);
10364 
10365   if (TrailingElts > VT.getVectorMinNumElements()) {
10366     SDValue VLBytes =
10367         DAG.getVScale(DL, PtrVT,
10368                       APInt(PtrVT.getFixedSizeInBits(),
10369                             VT.getStoreSize().getKnownMinValue()));
10370     TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VLBytes);
10371   }
10372 
10373   // Calculate the start address of the spliced result.
10374   StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
10375 
10376   // Load the spliced result
10377   return DAG.getLoad(VT, DL, StoreV2, StackPtr2,
10378                      MachinePointerInfo::getUnknownStack(MF));
10379 }
10380 
/// Legalize the condition code of a SETCC / VP_SETCC whose condition \p CC is
/// not natively supported for the operand type, by (in order of preference)
/// swapping operands, inverting the condition, or expanding into two legal
/// comparisons combined with AND/OR.
///
/// \param VT          Result type of the comparison.
/// \param LHS, RHS    Comparison operands, updated in place. When the
///                    comparison is fully expanded, \p LHS holds the final
///                    combined result and \p RHS is cleared.
/// \param CC          Condition-code node; replaced with a legal code, or
///                    cleared when the comparison was fully expanded.
/// \param Mask, EVL   VP operands — both set for VP nodes, both empty for
///                    non-VP nodes (asserted below).
/// \param NeedInvert  Set to true when the caller must logically invert the
///                    produced comparison to recover the original semantics.
/// \param Chain       For strict FP comparisons: incoming chain; updated with
///                    a TokenFactor of the expanded comparisons' chains.
/// \param IsSignaling Request signaling (quiet-NaN-trapping) FP comparison
///                    semantics.
/// \returns true if anything was rewritten, false if \p CC was already legal.
bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
                                           SDValue &LHS, SDValue &RHS,
                                           SDValue &CC, SDValue Mask,
                                           SDValue EVL, bool &NeedInvert,
                                           const SDLoc &dl, SDValue &Chain,
                                           bool IsSignaling) const {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  MVT OpVT = LHS.getSimpleValueType();
  ISD::CondCode CCCode = cast<CondCodeSDNode>(CC)->get();
  NeedInvert = false;
  assert(!EVL == !Mask && "VP Mask and EVL must either both be set or unset");
  bool IsNonVP = !EVL;
  switch (TLI.getCondCodeAction(CCCode, OpVT)) {
  default:
    llvm_unreachable("Unknown condition code action!");
  case TargetLowering::Legal:
    // Nothing to do.
    break;
  case TargetLowering::Expand: {
    // First try the swapped form of the condition (x CC y == y swap(CC) x).
    ISD::CondCode InvCC = ISD::getSetCCSwappedOperands(CCCode);
    if (TLI.isCondCodeLegalOrCustom(InvCC, OpVT)) {
      std::swap(LHS, RHS);
      CC = DAG.getCondCode(InvCC);
      return true;
    }
    // Swapping operands didn't work. Try inverting the condition.
    bool NeedSwap = false;
    InvCC = getSetCCInverse(CCCode, OpVT);
    if (!TLI.isCondCodeLegalOrCustom(InvCC, OpVT)) {
      // If inverting the condition is not enough, try swapping operands
      // on top of it.
      InvCC = ISD::getSetCCSwappedOperands(InvCC);
      NeedSwap = true;
    }
    if (TLI.isCondCodeLegalOrCustom(InvCC, OpVT)) {
      CC = DAG.getCondCode(InvCC);
      NeedInvert = true;
      if (NeedSwap)
        std::swap(LHS, RHS);
      return true;
    }

    // Neither swap nor invert gave a legal code: expand into two setccs
    // combined with Opc (AND/OR). CC1/CC2 are the two sub-conditions.
    ISD::CondCode CC1 = ISD::SETCC_INVALID, CC2 = ISD::SETCC_INVALID;
    unsigned Opc = 0;
    switch (CCCode) {
    default:
      llvm_unreachable("Don't know how to expand this condition!");
    case ISD::SETUO:
      // "unordered": isnan(LHS) || isnan(RHS).
      // Prefer (LHS SETUNE LHS) | (RHS SETUNE RHS); if SETUNE isn't legal,
      // fall through and compute SETO instead, inverting the result.
      if (TLI.isCondCodeLegal(ISD::SETUNE, OpVT)) {
        CC1 = ISD::SETUNE;
        CC2 = ISD::SETUNE;
        Opc = ISD::OR;
        break;
      }
      assert(TLI.isCondCodeLegal(ISD::SETOEQ, OpVT) &&
             "If SETUE is expanded, SETOEQ or SETUNE must be legal!");
      NeedInvert = true;
      [[fallthrough]];
    case ISD::SETO:
      // "ordered": !isnan(LHS) && !isnan(RHS), via
      // (LHS SETOEQ LHS) & (RHS SETOEQ RHS).
      assert(TLI.isCondCodeLegal(ISD::SETOEQ, OpVT) &&
             "If SETO is expanded, SETOEQ must be legal!");
      CC1 = ISD::SETOEQ;
      CC2 = ISD::SETOEQ;
      Opc = ISD::AND;
      break;
    case ISD::SETONE:
    case ISD::SETUEQ:
      // If the SETUO or SETO CC isn't legal, we might be able to use
      // SETOGT || SETOLT, inverting the result for SETUEQ. We only need one
      // of SETOGT/SETOLT to be legal, the other can be emulated by swapping
      // the operands.
      CC2 = ((unsigned)CCCode & 0x8U) ? ISD::SETUO : ISD::SETO;
      if (!TLI.isCondCodeLegal(CC2, OpVT) &&
          (TLI.isCondCodeLegal(ISD::SETOGT, OpVT) ||
           TLI.isCondCodeLegal(ISD::SETOLT, OpVT))) {
        CC1 = ISD::SETOGT;
        CC2 = ISD::SETOLT;
        Opc = ISD::OR;
        // SETUEQ = !(SETOGT || SETOLT); SETONE needs no inversion.
        NeedInvert = ((unsigned)CCCode & 0x8U);
        break;
      }
      [[fallthrough]];
    case ISD::SETOEQ:
    case ISD::SETOGT:
    case ISD::SETOGE:
    case ISD::SETOLT:
    case ISD::SETOLE:
    case ISD::SETUNE:
    case ISD::SETUGT:
    case ISD::SETUGE:
    case ISD::SETULT:
    case ISD::SETULE:
      // If we are floating point, assign and break, otherwise fall through.
      if (!OpVT.isInteger()) {
        // We can use the 4th bit to tell if we are the unordered
        // or ordered version of the opcode: SETU* codes have it set,
        // SETO* codes do not.
        CC2 = ((unsigned)CCCode & 0x8U) ? ISD::SETUO : ISD::SETO;
        Opc = ((unsigned)CCCode & 0x8U) ? ISD::OR : ISD::AND;
        // Keep the low three bits (the comparison kind) and set bit 4 to
        // form the ordering-agnostic integer-style code, e.g.
        // SETUGT/SETOGT -> SETGT.
        CC1 = (ISD::CondCode)(((int)CCCode & 0x7) | 0x10);
        break;
      }
      // Fallthrough if we are unsigned integer.
      [[fallthrough]];
    case ISD::SETLE:
    case ISD::SETGT:
    case ISD::SETGE:
    case ISD::SETLT:
    case ISD::SETNE:
    case ISD::SETEQ:
      // If all combinations of inverting the condition and swapping operands
      // didn't work then we have no means to expand the condition.
      llvm_unreachable("Don't know how to expand this condition!");
    }

    SDValue SetCC1, SetCC2;
    if (CCCode != ISD::SETO && CCCode != ISD::SETUO) {
      // If we aren't the ordered or unordered operation,
      // then the pattern is (LHS CC1 RHS) Opc (LHS CC2 RHS).
      if (IsNonVP) {
        SetCC1 = DAG.getSetCC(dl, VT, LHS, RHS, CC1, Chain, IsSignaling);
        SetCC2 = DAG.getSetCC(dl, VT, LHS, RHS, CC2, Chain, IsSignaling);
      } else {
        SetCC1 = DAG.getSetCCVP(dl, VT, LHS, RHS, CC1, Mask, EVL);
        SetCC2 = DAG.getSetCCVP(dl, VT, LHS, RHS, CC2, Mask, EVL);
      }
    } else {
      // Otherwise, the pattern is (LHS CC1 LHS) Opc (RHS CC2 RHS)
      if (IsNonVP) {
        SetCC1 = DAG.getSetCC(dl, VT, LHS, LHS, CC1, Chain, IsSignaling);
        SetCC2 = DAG.getSetCC(dl, VT, RHS, RHS, CC2, Chain, IsSignaling);
      } else {
        SetCC1 = DAG.getSetCCVP(dl, VT, LHS, LHS, CC1, Mask, EVL);
        SetCC2 = DAG.getSetCCVP(dl, VT, RHS, RHS, CC2, Mask, EVL);
      }
    }
    // Strict FP comparisons carry a chain: merge the chains of the two
    // expanded comparisons.
    if (Chain)
      Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, SetCC1.getValue(1),
                          SetCC2.getValue(1));
    if (IsNonVP)
      LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2);
    else {
      // Transform the binary opcode to the VP equivalent.
      assert((Opc == ISD::OR || Opc == ISD::AND) && "Unexpected opcode");
      Opc = Opc == ISD::OR ? ISD::VP_OR : ISD::VP_AND;
      LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2, Mask, EVL);
    }
    // LHS now holds the fully expanded result; signal this to the caller by
    // clearing RHS and CC.
    RHS = SDValue();
    CC = SDValue();
    return true;
  }
  }
  return false;
}
10534