package xiangshan.backend.fu

import chisel3._
import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
import chisel3.util._
import utils.SignExt
import xiangshan.backend.fu.util.CSA3_2

/** A Radix-4 SRT Integer Divider
 *
 * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
 *
 * State flow:
 *   s_idle -> s_lzd (leading-zero detection) -> s_normlize (align operands)
 *   -> s_recurrence (2 quotient bits per cycle) -> s_recovery_1 (fix a
 *   negative remainder) -> s_recovery_2 (shift the remainder back)
 *   -> s_finish.
 * A division by zero goes straight from s_idle to s_finish, and a flush of
 * the in-flight uop returns the FSM to s_idle from any state.
 *
 * When `sign` is set, operands are converted to magnitudes first; the
 * quotient takes sign(a) ^ sign(b) and the remainder takes sign(a).
 *
 * @param len operand and result width in bits; the W-variant result is built
 *            from res(31, 0), so len must be at least 32
 */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_finish :: Nil = Enum(7)
  val state = RegInit(s_idle)
  val newReq = (state === s_idle) && io.in.fire()
  // Remaining recurrence iterations. In s_normlize the counter is loaded with
  // (quotientBits + 3) >> 1, and quotientBits <= len - 1 (a difference of two
  // leading-zero counts of len-bit values), so the load is at most
  // (len + 3) / 2 and reaches it for even len. log2Up(n) bits only hold
  // 0 .. n-1, hence the "+ 1": without it, e.g. len = 30 or 62 truncates the
  // load to 0 and the count only comes out right by wrap-around.
  // The width is unchanged for len = 32 (5 bits) and len = 64 (6 bits).
  val cnt_next = Wire(UInt(log2Up((len+3)/2 + 1).W))
  val cnt = RegEnable(cnt_next, state===s_normlize || state===s_recurrence)
  // High on the last recurrence iteration.
  val rec_enough = cnt_next === 0.U

  /** Split `a` into (isNegative, magnitude).
   *
   * `a` is treated as negative only when `sign` is set and its MSB is 1; the
   * magnitude is then the two's-complement negation of `a`.
   */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }
  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))
  val uop = io.in.bits.uop
  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // Request-time values captured for use when the result is produced.
  val aSignReg = RegEnable(aSign, newReq)         // remainder sign
  val qSignReg = RegEnable(aSign ^ bSign, newReq) // quotient sign
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  // Abort an in-flight division whose uop is being flushed/redirected.
  val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  switch(state){
    is(s_idle){
      when (io.in.fire() && !io.in.bits.uop.roqIdx.needFlush(io.redirectIn, io.flushIn)) {
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd){ // leading zero detection
      state := s_normlize
    }
    is(s_normlize){ // shift a/b
      state := s_recurrence
    }
    is(s_recurrence){ // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough){ state := s_recovery_1 }
    }
    is(s_recovery_1){ // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2){ // recovery shift
      state := s_finish
    }
    is(s_finish){
      when(io.out.fire()){ state := s_idle }
    }
  }
  // Kept after the switch so that, by Chisel last-connect semantics, a kill
  // overrides whatever transition the switch selected.
  when(kill){
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
   *
   * ws, wc: partial remainder in carry-save form,
   *   in recurrence steps, ws/wc = 4ws[j]/4wc[j];
   *   in recovery step, ws/wc = ws[j]/wc[j];
   *   after s_recovery_2, ws holds the remainder magnitude
   *   abs(a) % abs(b) (the quotient is accumulated in q).
   *
   * d: normalized divisor (1/2 <= d < 1)
   *
   * wLen = 3 integer bits + (len+1) frac bits
   */
  def wLen = 3 + len + 1
  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // Leading-zero counts of |a| (in ws) and |b| (in d), sampled in s_lzd.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  // diff = lz(b) - lz(a). A negative diff means |a| has fewer significant
  // bits than |b|, so the quotient bit count is clamped to 0.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  // Radix 4 retires 2 bits per step; the dividend is pre-shifted by one extra
  // bit when quotientBits is even (see the s_normlize update of ws below).
  val qBitsIsOdd = quotientBits(0)
  // Shift used in s_recovery_2 to undo the divisor normalization on the remainder.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state===s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  // When |a| is the shorter operand, shift it by lz(b) so it stays aligned
  // with the normalized divisor.
  a_shifted := Mux(isNegDiff,
    ws(len-1, 0) << bLeadingZeros,
    ws(len-1, 0) << aLeadingZeros
  )
  b_shifted := d(len-1, 0) << bLeadingZeros

  // Resolve the carry-save remainder; if it is negative, add d back once.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen-1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift)(2*len, len+1)

  when(newReq){
    // On divide-by-zero, keep the raw dividend: it is returned as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize){
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence){
    // Store 4*w[j+1] for the next step; the last step stores w[j+1] unshifted.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1){
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2){
    ws := rem_abs
  }

  // Load the iteration count in s_normlize, then count down once per cycle.
  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
   *
   * The quotient selection table uses a truncated 7-bit remainder
   * and a 3-bit divisor.
   */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Completes the two's-complement negation for the subtracting digits:
  // -d = ~d + 1 and -2d = (~d << 1) + 2. It is injected into the low 2 bits of wc.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen-1, wLen-1-6) + wc(wLen-1, wLen-1-6)).asSInt()
  // The 3 divisor bits following the normalized leading one.
  val d_truncated = d(len-1, len-3)

  // Row i is used when d_truncated === i; columns are the thresholds for
  // selecting +2d, +d, 0, -d (otherwise -2d).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x, with one comparator per distinct threshold.
  val ge: Map[Int, Bool] =
    qSelTable.flatten.distinct.map(k => k -> (w_truncated >= k.S(7.W))).toMap

  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({case(v, i) => i.U -> v})
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
   *
   * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
   */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  csa.io.in(1) := Cat(wc(wLen-1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion: each step appends one radix-4 digit to q
  // and to qm (which tracks q - 1), using shifts only. Negative digits
  // therefore never need a carry-propagate subtraction.
  val q, qm = Reg(UInt(len.W))
  when(newReq){
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence){
    // Shift a 2-bit digit into the low end of src; its top 2 bits drop out.
    def shiftIn(src: UInt, digit: Int): UInt = Cat(src(len-3, 0), digit.U(2.W))
    // Q[j+1]  = 4Q[j] + q        if q >= 0, else 4QM[j] + (4 + q)
    q := MuxLookup(q_sel, 0.U, Seq(
      sel_0       -> shiftIn(q, 0),
      sel_d       -> shiftIn(q, 1),
      sel_dx2     -> shiftIn(q, 2),
      sel_neg_d   -> shiftIn(qm, 3),
      sel_neg_dx2 -> shiftIn(qm, 2)
    ))
    // QM[j+1] = 4Q[j] + (q - 1)  if q > 0,  else 4QM[j] + (3 + q)
    qm := MuxLookup(q_sel, 0.U, Seq(
      sel_0       -> shiftIn(qm, 3),
      sel_d       -> shiftIn(q, 0),
      sel_dx2     -> shiftIn(q, 1),
      sel_neg_d   -> shiftIn(qm, 2),
      sel_neg_dx2 -> shiftIn(qm, 1)
    ))
  }.elsewhen(state === s_recovery_1){
    // A negative final remainder means the quotient overshot by one.
    q := Mux(rem_temp(wLen-1), qm, q)
  }


  // Re-apply the operand signs to the magnitudes.
  val remainder = Mux(aSignReg, -ws(len-1, 0), ws(len-1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // isHi selects the remainder, otherwise the quotient. Division by zero
  // yields the dividend as the remainder and all ones as the quotient.
  val res = Mux(ctrlReg.isHi,
    Mux(divZeroReg, ws(len-1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )

  io.in.ready := state===s_idle
  io.out.valid := state===s_finish
  // W-variant ops sign-extend the low 32 bits of the result.
  io.out.bits.data := Mux(ctrlReg.isW,
    SignExt(res(31, 0), len),
    res
  )
  io.out.bits.uop := uopReg

}