package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
import chisel3.util._
import utils.SignExt
import xiangshan.XSModule
import xiangshan.backend.fu.util.CSA3_2

/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * Inputs:
  *  - src1 / src2: dividend / divisor.
  *  - sign: when set, an operand whose MSB is 1 is negated before the
  *    recurrence and the result signs are restored at the end.
  *  - kill_w: suppresses the idle -> busy transition of an accepted request
  *    (the operand registers are still loaded, since they are enabled by
  *    in_fire alone).
  *  - kill_r: forces the state machine back to idle from any state.
  *  - isHi: selects the remainder (true) or the quotient (false) as result.
  *  - isW: the low 32 bits of the result are passed through SignExt to `len`.
  *
  * Division by zero skips the recurrence entirely: the quotient output is
  * all ones and the remainder output is the unmodified src1.
  *
  * @param len operand width in bits
  */
class SRT4DividerDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    val src1, src2 = Input(UInt(len.W))
    val valid, sign, kill_w, kill_r, isHi, isW = Input(Bool())
    val in_ready = Output(Bool())
    val out_valid = Output(Bool())
    val out_data = Output(UInt(len.W))
    val out_ready = Input(Bool())
  })

  val (a, b, sign, valid, kill_w, kill_r, isHi, isW) =
    (io.src1, io.src2, io.sign, io.valid, io.kill_w, io.kill_r, io.isHi, io.isW)
  val in_fire = valid && io.in_ready
  val out_fire = io.out_ready && io.out_valid

  // s_pad_* is not used; the two padding states only exist so that s_finish
  // gets the encoding 8 and can be decoded from a single state bit.
  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_pad_1 :: s_pad_2 :: s_finish :: Nil = Enum(9)
  require(s_finish.litValue() == 8)

  val state = RegInit(s_idle)
  val finished = state(3).asBool // state === s_finish

  // Iteration counter: loaded in s_normlize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len + 3) / 2).W))
  val cnt = RegEnable(cnt_next, state === s_normlize || state === s_recurrence)
  // True on the last recurrence step (the counter is about to reach zero).
  val rec_enough = cnt_next === 0.U
  val newReq = in_fire

  /** Returns (isNegative, magnitude) of `a`; `a` is only treated as negative
    * when `sign` is set and its MSB is 1.
    */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // The remainder takes the sign of the dividend, the quotient the XOR of both.
  val aSignReg = RegEnable(aSign, newReq)
  val qSignReg = RegEnable(aSign ^ bSign, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  switch(state) {
    is(s_idle) {
      when(in_fire && !kill_w) {
        // A zero divisor needs no computation: jump straight to the result.
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd) { // leading zero detection
      state := s_normlize
    }
    is(s_normlize) { // shift a/b
      state := s_recurrence
    }
    is(s_recurrence) { // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough) {
        state := s_recovery_1
      }
    }
    is(s_recovery_1) { // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2) { // recovery shift
      state := s_finish
    }
    is(s_finish) {
      when(out_fire) {
        state := s_idle
      }
    }
  }
  // Placed after the switch so that it overrides any transition above.
  when(kill_r) {
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    * in recovery step, ws/wc = ws[j]/wc[j];
    * in final step, ws = abs(a)/abs(b).
    *
    * d: normalized divisor(1/2<=d<1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1

  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // During s_lzd, ws and d still hold the raw magnitudes loaded on newReq.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  // Zero-extend both counts by one bit so the subtraction can go negative.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state === s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len - 1, 0) << bLeadingZeros,
    ws(len - 1, 0) << aLeadingZeros
  )
  b_shifted := d(len - 1, 0) << bLeadingZeros

  // Carry-save pair resolved to a single value; a negative result is
  // corrected by adding the divisor back once.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen - 1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift) (2 * len, len + 1)

  when(newReq) {
    // On divide-by-zero keep the untouched dividend: it is the remainder result.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize) {
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence) {
    // The final step stores the remainder without the pre-multiplication by 4.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1) {
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2) {
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * the quotient selection table uses a truncated 7-bit remainder
    * and a 3-bit divisor
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Two's-complement correction for the one's-complemented divisor:
  // -d = ~d + 1 and -2d = (~d << 1) + 2. It is injected through the two low
  // bits of the carry word fed to the adder.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen - 1, wLen - 1 - 6) + wc(wLen - 1, wLen - 1 - 6)).asSInt()
  val d_truncated = d(len - 1, len - 3)

  // One row per d_truncated value; each row holds the lower bounds of
  // w_truncated for selecting the digits +2, +1, 0 and -1 (in that order).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // One comparator per distinct threshold, shared by all rows of the table.
  val ge: Map[Int, Bool] = qSelTable.toSeq
    .flatMap(_.toSeq)
    .distinct
    .map(k => k -> (w_truncated >= k.S(7.W)))
    .toMap

  // MuxCase picks the first matching condition, so the thresholds are tested
  // from the largest digit downwards; anything below the last bound is -2.
  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({ case (v, i) => i.U -> v })
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  csa.io.in(1) := Cat(wc(wLen - 1, 2), wc_adj)
  // Subtracting q*d: a positive digit adds the (complemented) negative
  // multiple, a negative digit adds the positive multiple.
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // q holds the quotient so far and qm holds q - 1; each step appends two
  // bits to one of them, so a negative digit never needs a borrow.
  val q, qm = Reg(UInt(len.W))
  when(newReq) {
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence) {
    // digit -> (register to extend, two bits to append)
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1) {
    // A negative final remainder means the quotient is one too large.
    q := Mux(rem_temp(wLen - 1), qm, q)
  }


  val remainder = Mux(aSignReg, -ws(len - 1, 0), ws(len - 1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Divide-by-zero: remainder = dividend (still held in ws), quotient = all ones.
  val res = Mux(isHi,
    Mux(divZeroReg, ws(len - 1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )
  io.out_data := Mux(isW,
    SignExt(res(31, 0), len),
    res
  )
  io.in_ready := state === s_idle
  io.out_valid := finished // state === s_finish
}

/** Function-unit wrapper around [[SRT4DividerDataModule]].
  *
  * Latches the uop and control bits of an accepted request, derives the two
  * kill signals from the redirect/flush inputs and forwards the handshake.
  *
  * NOTE(review): `io`, `ctrl` and `sign` are not defined in this file;
  * presumably they are provided by AbstractDivider - confirm there.
  *
  * @param len operand width in bits
  */
class SRT4Divider(len: Int)(implicit p: Parameters) extends AbstractDivider(len) {

  val newReq = io.in.fire()

  val uop = io.in.bits.uop
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)

  val divDataModule = Module(new SRT4DividerDataModule(len))

  // kill_w checks the incoming uop; kill_r checks the latched uop and is
  // only raised while the data module is busy (not ready for input).
  val kill_w = uop.roqIdx.needFlush(io.redirectIn, io.flushIn)
  val kill_r = !divDataModule.io.in_ready && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  divDataModule.io.src1 := io.in.bits.src(0)
  divDataModule.io.src2 := io.in.bits.src(1)
  divDataModule.io.valid := io.in.valid
  divDataModule.io.sign := sign
  divDataModule.io.kill_w := kill_w
  divDataModule.io.kill_r := kill_r
  // Result selection uses the latched control bits, not the live inputs.
  divDataModule.io.isHi := ctrlReg.isHi
  divDataModule.io.isW := ctrlReg.isW
  divDataModule.io.out_ready := io.out.ready

  io.in.ready := divDataModule.io.in_ready
  io.out.valid := divDataModule.io.out_valid
  io.out.bits.data := divDataModule.io.out_data
  io.out.bits.uop := uopReg
}