xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Radix2Divider.scala (revision 2d7c7105479bec3c329cf213502bd6a01cff7c0a)
1package xiangshan.backend.fu
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import utils._
7
/**
 * Common base for divider function units.
 *
 * Adds a `ctrl` input bundle (of type `MulDivCtrl`) alongside the IO inherited from
 * `FunctionUnit`, and exposes its `sign` field directly as `sign` for subclasses.
 *
 * @param len width parameter forwarded to `FunctionUnit`
 */
abstract class AbstractDivider(len: Int) extends FunctionUnit(len) {
  /** Per-operation control inputs for the divider. */
  val ctrl: MulDivCtrl = IO(Input(new MulDivCtrl))

  /** Shorthand for `ctrl.sign`. */
  val sign = ctrl.sign
}
12
/**
 * Iterative radix-2 restoring divider producing one quotient bit per cycle.
 *
 * Operands are `io.in.bits.src(0)` (dividend) and `io.in.bits.src(1)` (divisor), `len` bits wide.
 * When `sign` is set they are treated as two's complement: the division runs on their magnitudes
 * and the quotient / remainder signs are fixed up at the output.
 *
 * FSM: s_idle -> s_log2 -> s_shift -> s_compute (repeated) -> s_finish -> s_idle
 *  - s_log2:    computes how many leading iterations can be skipped because they could only
 *               produce quotient zeros.
 *  - s_shift:   pre-shifts the dividend by that amount.
 *  - s_compute: one restoring step per cycle until the counter reaches len - 1.
 *  - s_finish:  holds `io.out.valid` until `io.out.ready`.
 * Any non-idle state returns to s_idle when
 * `uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)` is true.
 *
 * Output: the remainder when `ctrl.isHi` is set, otherwise the quotient. When `ctrl.isW` is set,
 * the low 32 bits of that result are sign-extended to the output width.
 *
 * Divide by zero: the quotient is all ones and the remainder equals the dividend.
 *
 * @param len operand width in bits
 */
class Radix2Divider(len: Int) extends AbstractDivider(len) {

  /**
   * Splits `a` into (isNegative, magnitude).
   * `a` is considered negative only when `sign` is set and its MSB is 1.
   */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val s_idle :: s_log2 :: s_shift :: s_compute :: s_finish :: Nil = Enum(5)
  val state = RegInit(s_idle)
  // A new division is accepted only from the idle state.
  val newReq = (state === s_idle) && io.in.fire()

  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))
  val divBy0 = b === 0.U(len.W)
  val divBy0Reg = RegEnable(divBy0, newReq)

  // Working register: {partial remainder (len + 1 bits), remaining dividend / quotient bits (len bits)}.
  val shiftReg = Reg(UInt((1 + len * 2).W))
  val hi = shiftReg(len * 2, len)
  val lo = shiftReg(len - 1, 0)

  val uop = io.in.bits.uop

  // Magnitudes and the sign information needed for the final fix-up, captured on acceptance.
  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  val aSignReg = RegEnable(aSign, newReq)
  // Never negate the quotient on divide-by-zero so that it stays all ones.
  val qSignReg = RegEnable((aSign ^ bSign) && !divBy0, newReq)
  val bReg = RegEnable(bVal, newReq)
  // Dividend pre-shifted by one so that the first compute step compares its MSB with the divisor.
  val aValx2Reg = RegEnable(Cat(aVal, "b0".U), newReq)
  val ctrlReg = RegEnable(ctrl, newReq)
  val uopReg = RegEnable(uop, newReq)

  // Wide enough to hold len + Log2(bReg) <= 2 * len - 1 without wrapping.
  val skipWidth = log2Ceil(2 * len)

  val cnt = Counter(len)
  when (newReq) {
    state := s_log2
  } .elsewhen (state === s_log2) {
    // Before any subtraction, compute step i compares the top (i + 1) bits of aVal with bReg.
    // Skipping the first k steps is therefore safe whenever the top k bits of aVal are < bReg,
    // which holds if that k-bit prefix has fewer significant bits than bReg:
    //   aBits = Log2(aVal) + 1,  bBits = Log2(bReg) + 1
    //   aBits - (len - k) < bBits   <=>   k <= len + Log2(bReg) - Log2(aVal) - 1
    // Because aValx2Reg = aVal << 1, Log2(aValx2Reg) = Log2(aVal) + 1 for aVal != 0, so
    //   canSkipShift = len + Log2(bReg) - Log2(aValx2Reg).
    // For aVal == 0, Log2(aValx2Reg) = 0 and the value is >= len (clamped below).
    // The subtraction never underflows since Log2(aValx2Reg) <= len.
    // (Previously written as (64.U | Log2(bReg)), which is only valid for len == 64.)
    val canSkipShift = (len.U(skipWidth.W) + Log2(bReg)) - Log2(aValx2Reg)
    // When dividing by 0 the quotient must be all ones, so 0s must not be shifted in here:
    // no step is skipped in that case. Otherwise at least one compute step always runs.
    cnt.value := Mux(divBy0Reg, 0.U, Mux(canSkipShift >= (len-1).U, (len-1).U, canSkipShift))
    state := s_shift
  } .elsewhen (state === s_shift) {
    shiftReg := aValx2Reg << cnt.value
    state := s_compute
  } .elsewhen (state === s_compute) {
    // Restoring step: subtract when the partial remainder covers the divisor and shift in
    // the resulting quotient bit. The remainder is < bReg afterwards, so its low len bits suffice.
    val enough = hi >= bReg
    shiftReg := Cat(Mux(enough, hi - bReg, hi)(len - 1, 0), lo, enough)
    cnt.inc()
    when (cnt.value === (len-1).U) { state := s_finish }
  } .elsewhen (state === s_finish) {
    when(io.out.ready){
      state := s_idle
    }
  }

  // Drop the in-flight operation when its uop is flushed; this overrides the FSM update above.
  val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)
  when(kill){
    state := s_idle
  }

  // The last compute step also shifted the remainder left by one; undo that here.
  val r = hi(len, 1)
  val resQ = Mux(qSignReg, -lo, lo)
  // The remainder takes the sign of the dividend.
  val resR = Mux(aSignReg, -r, r)

  val xlen = io.out.bits.data.getWidth
  val res = Mux(ctrlReg.isHi, resR, resQ)
  io.out.bits.data := Mux(ctrlReg.isW, SignExt(res(31,0),xlen), res)
  io.out.bits.uop := uopReg

  io.out.valid := state === s_finish
  io.in.ready := state === s_idle
}