xref: /XiangShan/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala (revision 2225d46ebbe2fd16b9b29963c27a7d0385a42709)
15018a303SLinJiaweipackage xiangshan.backend.fu
25018a303SLinJiawei
3*2225d46eSJiawei Linimport chipsalliance.rocketchip.config.Parameters
45018a303SLinJiaweiimport chisel3._
5f93cfde5SLinJiaweiimport chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
65018a303SLinJiaweiimport chisel3.util._
75018a303SLinJiaweiimport utils.SignExt
8afefbad5SLinJiaweiimport xiangshan.XSModule
97f1506e3SLinJiaweiimport xiangshan.backend.fu.util.CSA3_2
105018a303SLinJiawei
/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * Control flow: a zero divisor goes from s_idle directly to s_finish. Otherwise the request
  * walks s_lzd -> s_normalize -> s_recurrence (one or more cycles, two quotient bits per
  * cycle) -> s_recovery_1 -> s_recovery_2 -> s_finish. The result is held in s_finish until
  * `out_ready` is seen.
  *
  * @param len width in bits of both operands and of the result
  */
class SRT4DividerDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    val src1, src2 = Input(UInt(len.W)) // src1: dividend, src2: divisor
    // sign: treat operands as two's complement; isHi: return remainder instead of quotient;
    // isW: return the low 32 bits of the result, sign-extended to len bits;
    // kill_w: prevent an accepted request from starting; kill_r: force the FSM back to idle.
    val valid, sign, kill_w, kill_r, isHi, isW = Input(Bool())
    val in_ready = Output(Bool())
    val out_valid = Output(Bool())
    val out_data = Output(UInt(len.W))
    val out_ready = Input(Bool())
  })

  val (a, b, sign, valid, kill_w, kill_r, isHi, isW) =
    (io.src1, io.src2, io.sign, io.valid, io.kill_w, io.kill_r, io.isHi, io.isW)
  val in_fire = valid && io.in_ready
  val out_fire = io.out_ready && io.out_valid

  // s_pad_* are never entered. They only pad the encoding so that s_finish == 8, which lets
  // `finished` below be a single state bit instead of a full comparison.
  val s_idle :: s_lzd :: s_normalize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_pad_1 :: s_pad_2 :: s_finish :: Nil = Enum(9)
  require(s_finish.litValue() == 8)

  val state = RegInit(s_idle)
  val finished = state(3).asBool // state === s_finish (8 is the only encoding with bit 3 set)

  // Remaining recurrence iterations: loaded in s_normalize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len + 3) / 2).W))
  val cnt = RegEnable(cnt_next, state === s_normalize || state === s_recurrence)
  val rec_enough = cnt_next === 0.U
  val newReq = in_fire

  /** Returns (isNegative, magnitude) of `a`. `a` counts as negative only when `sign` is set. */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  val aSignReg = RegEnable(aSign, newReq)         // remainder takes the dividend's sign
  val qSignReg = RegEnable(aSign ^ bSign, newReq) // quotient is negative iff exactly one operand is
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  switch(state) {
    is(s_idle) {
      when(in_fire && !kill_w) {
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd) { // leading zero detection
      state := s_normalize
    }
    is(s_normalize) { // shift a/b
      state := s_recurrence
    }
    is(s_recurrence) { // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough) {
        state := s_recovery_1
      }
    }
    is(s_recovery_1) { // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2) { // recovery shift
      state := s_finish
    }
    is(s_finish) {
      when(out_fire) {
        state := s_idle
      }
    }
  }
  // Placed after the switch so that kill_r overrides any transition above (last connect wins).
  when(kill_r) {
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form;
    * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    * in recovery step, ws/wc = ws[j]/wc[j];
    * in final step, ws = abs(a)/abs(b).
    *
    * d: normalized divisor (1/2 <= d < 1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1

  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // Leading-zero counts of the low len bits: index of the first set bit scanning from the MSB.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  // Number of quotient bits to produce. It is clamped to 0 when the divisor has more
  // significant bits than the dividend.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount used in s_recovery_2 to undo the divisor normalization on the remainder.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state === s_normalize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len - 1, 0) << bLeadingZeros,
    ws(len - 1, 0) << aLeadingZeros
  )
  b_shifted := d(len - 1, 0) << bLeadingZeros

  // Recovery: resolve the carry-save remainder; if it is negative (MSB set), add d back.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen - 1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift) (2 * len, len + 1)

  when(newReq) {
    // On divide-by-zero the raw dividend is kept so it can be returned as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normalize) {
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    // One extra bit of dividend shift when quotientBits is even; this pairs with
    // cnt = (quotientBits + 3) >> 1 iterations of 2 quotient bits each.
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence) {
    // The final iteration is stored unscaled; earlier ones are pre-multiplied by 4.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1) {
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2) {
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normalize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * The quotient selection table uses the truncated 7-bit remainder
    * and 3 divisor bits below the leading one.
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1  // equals -2d - 2; the missing '2' is also added via wc_adj

  val q_sel = Wire(UInt(3.W))
  // Two's-complement correction for neg_d / neg_dx2. The low 2 bits of wc are always zero here,
  // because wc is reset to 0 and every carry word is shifted left before it is stored.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen - 1, wLen - 1 - 6) + wc(wLen - 1, wLen - 1 - 6)).asSInt()
  val d_truncated = d(len - 1, len - 3)

  // One row per d_truncated value; columns are the thresholds for selecting
  // +2, +1, 0 and -1 (anything below the last threshold selects -2).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x, with a single comparator per distinct threshold.
  val ge: Map[Int, Bool] =
    qSelTable.flatten.distinct.map(k => k -> (w_truncated >= k.S(7.W))).toMap

  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({ case (v, i) => i.U -> v })
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  csa.io.in(1) := Cat(wc(wLen - 1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion.
  // Invariant maintained by the tables below: qm == q - 1. Each cycle shifts in 2 bits.
  val q, qm = Reg(UInt(len.W))
  when(newReq) {
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence) {
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1) {
    // A negative final remainder means the quotient is one too large: use q - 1.
    q := Mux(rem_temp(wLen - 1), qm, q)
  }

  // Restore signs: remainder follows the dividend, quotient follows aSign ^ bSign.
  val remainder = Mux(aSignReg, -ws(len - 1, 0), ws(len - 1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Divide by zero: remainder is the original dividend, quotient is all ones.
  val res = Mux(isHi,
    Mux(divZeroReg, ws(len - 1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )
  io.out_data := Mux(isW,
    SignExt(res(31, 0), len),
    res
  )
  io.in_ready := state === s_idle
  io.out_valid := finished // state === s_finish
}
2565018a303SLinJiawei
/** Pipeline-facing wrapper around [[SRT4DividerDataModule]].
  *
  * It latches the micro-op and its control fields when a request is accepted. It derives the
  * data module's kill signals from the redirect/flush inputs. The output carries the latched
  * uop alongside the computed data.
  *
  * NOTE(review): `ctrl` and `sign` are inherited from AbstractDivider, which is not visible
  * here. Their exact derivation should be confirmed there.
  *
  * @param len operand/result width in bits
  */
class SRT4Divider(len: Int)(implicit p: Parameters) extends AbstractDivider(len) {

  val newReq = io.in.fire()

  // Snapshot the request's uop and control fields so they remain stable while the
  // multi-cycle division is in flight.
  val uop = io.in.bits.uop
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)

  val divDataModule = Module(new SRT4DividerDataModule(len))

  // kill_w: checks the uop currently at the input port; it stops that request from starting.
  val kill_w = uop.roqIdx.needFlush(io.redirectIn, io.flushIn)
  // kill_r: checks the latched uop, but only while the divider is busy (not ready for input).
  // It returns the FSM to idle.
  val kill_r = !divDataModule.io.in_ready && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  divDataModule.io.src1 := io.in.bits.src(0)
  divDataModule.io.src2 := io.in.bits.src(1)
  divDataModule.io.valid := io.in.valid
  divDataModule.io.sign := sign
  divDataModule.io.kill_w := kill_w
  divDataModule.io.kill_r := kill_r
  divDataModule.io.isHi := ctrlReg.isHi
  divDataModule.io.isW := ctrlReg.isW
  divDataModule.io.out_ready := io.out.ready

  io.in.ready := divDataModule.io.in_ready
  io.out.valid := divDataModule.io.out_valid
  io.out.bits.data := divDataModule.io.out_data
  io.out.bits.uop := uopReg
}
285