package xiangshan.backend.fu

import chisel3._
import chisel3.util._
import utils.SignExt
import xiangshan.backend.fu.util.CSA3_2

/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * The divider is a small FSM:
  *   s_idle       -> wait for a request (division by zero jumps straight to s_finish)
  *   s_lzd        -> count leading zeros of |a| and |b|
  *   s_normlize   -> left-align dividend/divisor, compute the iteration count
  *   s_recurrence -> radix-4 SRT iterations (one quotient digit in {-2..2} per cycle)
  *   s_recovery   -> correct a negative final remainder and pick q or q-1
  *   s_finish     -> hold the result until the consumer accepts it
  *
  * @param len operand width in bits
  *
  * NOTE(review): `sign`, `ctrl` and `io` are not defined in this file; they are
  * presumably provided by AbstractDivider -- confirm against the base class.
  */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery :: s_finish :: Nil = Enum(6)
  val state = RegInit(s_idle)
  // A request is only accepted while idle (io.in.ready is tied to s_idle below).
  val newReq = (state === s_idle) && io.in.fire()
  // Remaining recurrence iterations; loaded in s_normlize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len+3)/2).W))
  val cnt = RegEnable(cnt_next, state===s_normlize || state===s_recurrence)
  val rec_enough = cnt_next === 0.U

  /** Split an operand into (isNegative, magnitude).
    *
    * The operand is treated as negative only when `sign` is set and its MSB is 1;
    * otherwise it is passed through unchanged.
    */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }
  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))
  val uop = io.in.bits.uop
  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // Latch everything needed to post-process the result when the request is accepted.
  val aSignReg = RegEnable(aSign, newReq)          // the remainder takes the sign of the dividend
  val qSignReg = RegEnable(aSign ^ bSign, newReq)  // the quotient is negative iff the signs differ
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  // Abort the in-flight operation when needFlush reports that its uop is flushed
  // by the incoming redirect.
  val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn)

  switch(state){
    is(s_idle){
      // Division by zero needs no iteration: go directly to the result state.
      when(io.in.fire()){ state := Mux(divZero, s_finish, s_lzd) }
    }
    is(s_lzd){ // leading zero detection
      state := s_normlize
    }
    is(s_normlize){ // shift a/b
      state := s_recurrence
    }
    is(s_recurrence){ // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough){ state := s_recovery }
    }
    is(s_recovery){ // if rem < 0, rem = rem + d
      state := s_finish
    }
    is(s_finish){
      when(io.out.fire()){ state := s_idle }
    }
  }
  // Placed after the switch so that a kill overrides any other state transition
  // (last connection wins).
  when(kill){
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    * in recovery step, ws/wc = ws[j]/wc[j];
    * in final step, ws holds the magnitude of the remainder
    * (the quotient is kept separately in q).
    *
    * d: normalized divisor(1/2<=d<1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1
  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // In s_lzd, ws and d still hold the raw magnitudes loaded at newReq, so their
  // low len bits are |a| and |b|. Reversing the bits makes PriorityEncoder
  // return the index of the first set bit counted from the MSB.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  // Zero-extend by one bit before the signed subtraction so the sign of the
  // difference is representable.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  // Clamped to zero when a has more leading zeros than b.
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount used in s_recovery to bring the remainder back from the
  // normalized domain.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state===s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len-1, 0) << bLeadingZeros,
    ws(len-1, 0) << aLeadingZeros
  )
  b_shifted := d(len-1, 0) << bLeadingZeros

  // Resolve the carry-save remainder; a set MSB means it is negative and one
  // divisor must be added back.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen-1), rem_temp + d, rem_temp)
  val rem_abs = (rem_fixed << recoveryShift)(2*len, len+1)

  when(newReq){
    // On division by zero the raw dividend is stored so that the remainder
    // result below is the unmodified dividend.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize){
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence){
    // Pre-multiply by 4 for the next iteration, except after the last one.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery){
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * the quotient selection table uses a truncated 7-bit remainder
    * and a 3-bit divisor
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Two's-complement correction for the inverted divisor:
  // -d = ~d + 1, and -2d = (~d << 1) + 2.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  // Top 7 bits of the carry-save remainder, added and interpreted as signed.
  val w_truncated = (ws(wLen-1, wLen-1-6) + wc(wLen-1, wLen-1-6)).asSInt()
  // Three divisor bits used as the row index into qSelTable.
  val d_truncated = d(len-1, len-3)

  // One row per d_truncated value; the four thresholds in a row separate the
  // quotient digits 2 | 1 | 0 | -1 | -2 (see the MuxCase below).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // One comparator is built per distinct threshold and shared between rows.
  var ge = Map[Int, Bool]()
  for(row <- qSelTable){
    for(k <- row){
      if(!ge.contains(k)) ge = ge + (k -> (w_truncated >= k.S(7.W)))
    }
  }
  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      // MuxCase picks the first true condition, so thresholds are tested from
      // the largest downwards.
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({case(v, i) => i.U -> v})
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  // The two low bits of wc are replaced by the two's-complement correction.
  csa.io.in(1) := Cat(wc(wLen-1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // q holds the quotient accumulated so far and qm holds q - 1; each iteration
  // appends one radix-4 digit (two bits) to one of them.
  val q, qm = Reg(UInt(len.W))
  when(newReq){
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence){
    // digit -> (register to extend, two bits appended)
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery){
    // A negative final remainder means the quotient is one too large.
    q := Mux(rem_temp(wLen-1), qm, q)
  }


  // Re-apply the signs latched at request time.
  val remainder = Mux(aSignReg, -ws(len-1, 0), ws(len-1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // isHi selects the remainder, otherwise the quotient. On division by zero the
  // remainder is the stored dividend and the quotient is all ones.
  val res = Mux(ctrlReg.isHi,
    Mux(divZeroReg, ws(len-1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )

  io.in.ready := state===s_idle
  io.out.valid := state===s_finish && !kill
  // When isW is set, the low 32 bits of the result are sign-extended to len bits.
  io.out.bits.data := Mux(ctrlReg.isW,
    SignExt(res(31, 0), len),
    res
  )
  io.out.bits.uop := uopReg

}