package xiangshan.backend.fu

import chisel3._
import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
import chisel3.util._
import utils.SignExt
import xiangshan.XSModule
import xiangshan.backend.fu.util.CSA3_2

/** A Radix-4 SRT Integer Divider
 *
 * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
 *
 * Ports:
 *  - `src1` / `src2`: dividend / divisor, sampled when `valid && in_ready`.
 *  - `sign`: treat the operands as two's-complement signed numbers.
 *  - `kill_w`: the request being written this cycle must not be started.
 *  - `kill_r`: abort the division in flight and return to idle.
 *  - `isHi`: select the remainder (true) or the quotient (false) as result.
 *  - `isW`: sign-extend the low 32 bits of the result to `len` bits.
 *
 * @param len operand width in bits; must be at least 32 because the `isW`
 *            result path extracts bits 31..0 of the result
 */
class SRT4DividerDataModule(len: Int) extends Module {
  // Make the width constraint explicit instead of failing later on an
  // out-of-range bit extraction (`res(31, 0)`).
  require(len >= 32, s"SRT4DividerDataModule requires len >= 32 (got $len): the isW path reads res(31, 0)")

  val io = IO(new Bundle() {
    val src1, src2 = Input(UInt(len.W))
    val valid, sign, kill_w, kill_r, isHi, isW = Input(Bool())
    val in_ready = Output(Bool())
    val out_valid = Output(Bool())
    val out_data = Output(UInt(len.W))
    val out_ready = Input(Bool())
  })

  val (a, b, sign, valid, kill_w, kill_r, isHi, isW) =
    (io.src1, io.src2, io.sign, io.valid, io.kill_w, io.kill_r, io.isHi, io.isW)
  // Input / output handshakes.
  val in_fire = valid && io.in_ready
  val out_fire = io.out_ready && io.out_valid

  // s_pad_* is not used; the two padding states only exist so that
  // s_finish is encoded as 8, the single state with bit 3 set.
  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_pad_1 :: s_pad_2 :: s_finish :: Nil = Enum(9)
  require(s_finish.litValue() == 8)

  val state = RegInit(s_idle)
  val finished = state(3).asBool // state === s_finish

  // Iteration counter: loaded during s_normlize, decremented during
  // s_recurrence; the recurrence ends when the next count reaches zero.
  val cnt_next = Wire(UInt(log2Up((len + 3) / 2).W))
  val cnt = RegEnable(cnt_next, state === s_normlize || state === s_recurrence)
  val rec_enough = cnt_next === 0.U
  // Operand registers are loaded on every input handshake, even when the
  // request is killed by kill_w (only the state transition is suppressed).
  val newReq = in_fire

  /** Returns (isNegative, magnitude) of `a`; `a` is only treated as
   * negative when `sign` is set and its top bit is 1.
   */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  val aSignReg = RegEnable(aSign, newReq) // sign of the remainder
  val qSignReg = RegEnable(aSign ^ bSign, newReq) // sign of the quotient
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  switch(state) {
    is(s_idle) {
      when(in_fire && !kill_w) {
        // Division by zero needs no computation: jump straight to s_finish.
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd) { // leading zero detection
      state := s_normlize
    }
    is(s_normlize) { // shift a/b
      state := s_recurrence
    }
    is(s_recurrence) { // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough) {
        state := s_recovery_1
      }
    }
    is(s_recovery_1) { // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2) { // recovery shift
      state := s_finish
    }
    is(s_finish) {
      when(out_fire) {
        state := s_idle
      }
    }
  }
  // Placed after the switch so that a kill overrides any other transition.
  when(kill_r) {
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
   *
   * ws, wc: partial remainder in carry-save form,
   * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
   * in recovery step, ws/wc = ws[j]/wc[j];
   * in final step, ws = abs(a)/abs(b).
   *
   * d: normalized divisor(1/2<=d<1)
   *
   * wLen = 3 integer bits + (len+1) frac bits
   */
  def wLen = 3 + len + 1

  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // Leading-zero counts of |a| (held in ws) and |b| (held in d), captured
  // in s_lzd. The bit vector is reversed so the encoder scans from the MSB.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  // Signed difference of the leading-zero counts; a negative difference is
  // clamped to zero quotient bits.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount used in s_recovery_2 to undo the divisor normalization.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state === s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len - 1, 0) << bLeadingZeros,
    ws(len - 1, 0) << aLeadingZeros
  )
  b_shifted := d(len - 1, 0) << bLeadingZeros

  // Remainder assembled from the carry-save pair; if it is negative (top
  // bit set) one divisor is added back.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen - 1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift) (2 * len, len + 1)

  when(newReq) {
    // On divide-by-zero keep the raw dividend: it is returned as remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize) {
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence) {
    // The last iteration keeps the un-scaled remainder for the recovery step.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1) {
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2) {
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
   *
   * the quotient selection table use truncated 7-bit remainder
   * and 3-bit divisor
   */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Two's-complement correction injected into the low bits of wc:
  // +1 completes -d, +2 completes -2d (because neg_dx2 = (~d) << 1).
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen - 1, wLen - 1 - 6) + wc(wLen - 1, wLen - 1 - 6)).asSInt()
  val d_truncated = d(len - 1, len - 3)

  // One row per value of d_truncated; each row holds the lower bounds of
  // w_truncated for selecting 2, 1, 0 and -1 (anything below selects -2).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // One comparator per distinct threshold, shared by all table rows.
  val ge: Map[Int, Bool] =
    qSelTable.flatten.distinct.map(k => k -> (w_truncated >= k.S(7.W))).toMap

  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({ case (v, i) => i.U -> v })
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
   *
   * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
   */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  csa.io.in(1) := Cat(wc(wLen - 1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // q holds the quotient so far and qm holds q - 1; each step appends one
  // radix-4 digit (two bits) to whichever of the two gives the new value.
  val q, qm = Reg(UInt(len.W))
  when(newReq) {
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence) {
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1) {
    // A negative final remainder means the quotient is one too large.
    q := Mux(rem_temp(wLen - 1), qm, q)
  }


  // Re-apply the signs recorded when the request was accepted.
  val remainder = Mux(aSignReg, -ws(len - 1, 0), ws(len - 1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Divide-by-zero results: remainder = dividend (kept in ws), quotient = all ones.
  val res = Mux(isHi,
    Mux(divZeroReg, ws(len - 1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )
  io.out_data := Mux(isW,
    SignExt(res(31, 0), len),
    res
  )
  io.in_ready := state === s_idle
  io.out_valid := finished // state === s_finish
}

/** Function-unit wrapper around [[SRT4DividerDataModule]].
 *
 * Latches the uop and control bits of the accepted request, derives the
 * kill signals from the redirect/flush inputs and forwards the handshake.
 *
 * NOTE(review): `io`, `ctrl` and `sign` are not defined in this file;
 * presumably they are provided by `AbstractDivider` -- confirm there.
 *
 * @param len operand width in bits, passed on to the data module
 */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val newReq = io.in.fire()

  val uop = io.in.bits.uop
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)

  val divDataModule = Module(new SRT4DividerDataModule(len))

  // kill_w: the incoming uop is being flushed, so it must not be started.
  val kill_w = uop.roqIdx.needFlush(io.redirectIn, io.flushIn)
  // kill_r: the uop currently held by the divider (not idle) is being flushed.
  val kill_r = !divDataModule.io.in_ready && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  divDataModule.io.src1 := io.in.bits.src(0)
  divDataModule.io.src2 := io.in.bits.src(1)
  divDataModule.io.valid := io.in.valid
  divDataModule.io.sign := sign
  divDataModule.io.kill_w := kill_w
  divDataModule.io.kill_r := kill_r
  // Result selection uses the latched control bits of the request in flight.
  divDataModule.io.isHi := ctrlReg.isHi
  divDataModule.io.isW := ctrlReg.isW
  divDataModule.io.out_ready := io.out.ready

  io.in.ready := divDataModule.io.in_ready
  io.out.valid := divDataModule.io.out_valid
  io.out.bits.data := divDataModule.io.out_data
  io.out.bits.uop := uopReg
}