package xiangshan.backend.fu

import chisel3._
import chisel3.util._
import utils.SignExt
import xiangshan.backend.fu.fpu.util.CSA3_2

/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * State sequence for a normal division:
  *   s_idle -> s_lzd -> s_normlize -> s_recurrence (repeated) -> s_recovery -> s_finish
  * A division by zero skips straight from s_idle to s_finish.
  *
  * `sign`, `ctrl` and `io` are not defined in this class; they are presumably
  * provided by [[AbstractDivider]] -- NOTE(review): confirm against the base class.
  *
  * @param len operand width in bits
  */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery :: s_finish :: Nil = Enum(6)
  val state = RegInit(s_idle)
  // A request is only accepted while idle (io.in.ready is tied to s_idle below).
  val newReq = (state === s_idle) && io.in.fire()
  // Recurrence iteration counter: loaded in s_normlize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len+3)/2).W))
  val cnt = RegEnable(cnt_next, state===s_normlize || state===s_recurrence)
  // True in the last recurrence cycle (the counter is about to reach zero).
  val rec_enough = cnt_next === 0.U

  /** Returns (isNegative, magnitude) of `a`.
    *
    * `a` is treated as negative only when `sign` is set and its top bit is 1;
    * otherwise it is passed through unchanged.
    */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }
  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))
  val uop = io.in.bits.uop
  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // Per-request information latched when the request is accepted.
  val aSignReg = RegEnable(aSign, newReq)         // the remainder takes the sign of the dividend
  val qSignReg = RegEnable(aSign ^ bSign, newReq) // the quotient is negative iff the signs differ
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  // Only an in-flight operation can be killed. While idle, uopReg still holds the
  // uop of the previous (already finished or flushed) operation; without the state
  // qualifier a redirect matching that stale uop would override the s_idle
  // transition below and silently drop a request that was accepted in this cycle.
  val kill = state =/= s_idle && uopReg.roqIdx.needFlush(io.redirectIn)

  switch(state){
    is(s_idle){
      when(io.in.fire()){ state := Mux(divZero, s_finish, s_lzd) }
    }
    is(s_lzd){ // leading zero detection
      state := s_normlize
    }
    is(s_normlize){ // shift a/b
      state := s_recurrence
    }
    is(s_recurrence){ // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough){ state := s_recovery }
    }
    is(s_recovery){ // if rem < 0, rem = rem + d
      state := s_finish
    }
    is(s_finish){
      when(io.out.fire()){ state := s_idle }
    }
  }
  // Placed after the switch so that a flush takes priority over every transition.
  when(kill){
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    * in recovery step, ws/wc = ws[j]/wc[j];
    * in final step, ws = abs(a) % abs(b) (the remainder; the quotient is held in q).
    *
    * d: normalized divisor (1/2 <= d < 1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1
  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // In s_lzd, ws and d still hold the raw magnitudes loaded on newReq; count the
  // leading zeros of each (priority-encode the bit-reversed value).
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  // Zero-extend before subtracting so that the signed difference cannot overflow.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  // Clamped to zero when a has more leading zeros than b.
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount applied in s_recovery to bring the remainder back to integer scale.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state===s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len-1, 0) << bLeadingZeros,
    ws(len-1, 0) << aLeadingZeros
  )
  b_shifted := d(len-1, 0) << bLeadingZeros

  // Resolve the carry-save remainder; a negative value is corrected by adding d.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen-1), rem_temp + d, rem_temp)
  val rem_abs = (rem_fixed << recoveryShift)(2*len, len+1)

  when(newReq){
    // On divide-by-zero the original dividend is kept so that it can be returned
    // unchanged as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize){
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    // Each radix-4 step retires two quotient bits; an even quotientBits gets one
    // extra left shift of the dividend.
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence){
    // Keep the residual pre-multiplied by 4, except after the final iteration.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery){
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * the quotient selection table uses a truncated 7-bit remainder
    * and a 3-bit divisor
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // The "+1" (for -d) or "+2" (for -2d) that completes the two's-complement
  // negation; it is injected through the two low bits of the carry word below.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  // Top 7 bits of the carry-save residual, summed and interpreted as signed.
  val w_truncated = (ws(wLen-1, wLen-1-6) + wc(wLen-1, wLen-1-6)).asSInt()
  // The three bits directly below bit `len` of d (where the top bit of the
  // normalized divisor sits); they select the row of the table.
  val d_truncated = d(len-1, len-3)

  // One row per d_truncated value. Each row holds the thresholds on w_truncated
  // for choosing, in order, +2d, +d, 0 and -d; anything below the last threshold
  // selects -2d.
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // Built once per distinct threshold so that comparators are shared between rows.
  var ge = Map[Int, Bool]()
  for(row <- qSelTable){
    for(k <- row){
      if(!ge.contains(k)) ge = ge + (k -> (w_truncated >= k.S(7.W)))
    }
  }
  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({case(v, i) => i.U -> v})
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  // The two low bits of wc are replaced by the negation carry-in (wc_adj).
  csa.io.in(1) := Cat(wc(wLen-1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // q accumulates the quotient and qm tracks q - 1, so a negative digit can be
  // appended without a borrow ripple. Each step appends one radix-4 digit (2 bits).
  val q, qm = Reg(UInt(len.W))
  when(newReq){
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence){
    // selected digit -> (register to extend, 2-bit digit to append)
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery){
    // A negative final remainder means the quotient is one too large: use q - 1.
    q := Mux(rem_temp(wLen-1), qm, q)
  }


  // Re-apply the signs that were stripped by abs().
  val remainder = Mux(aSignReg, -ws(len-1, 0), ws(len-1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Divide-by-zero results: remainder = dividend, quotient = all ones.
  val res = Mux(ctrlReg.isHi,
    Mux(divZeroReg, ws(len-1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )

  io.in.ready := state===s_idle
  io.out.valid := state===s_finish && !kill
  // When ctrlReg.isW is set, the low 32 bits of the result are sign-extended to len bits.
  io.out.bits.data := Mux(ctrlReg.isW,
    SignExt(res(31, 0), len),
    res
  )
  io.out.bits.uop := uopReg

}