xref: /XiangShan/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala (revision 5018a303094b1186a1faf476806f7600fc33d0ba)
1*5018a303SLinJiaweipackage xiangshan.backend.fu
2*5018a303SLinJiawei
3*5018a303SLinJiaweiimport chisel3._
4*5018a303SLinJiaweiimport chisel3.util._
5*5018a303SLinJiaweiimport utils.SignExt
6*5018a303SLinJiaweiimport xiangshan.backend.fu.fpu.util.CSA3_2
7*5018a303SLinJiawei
/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * The division walks through the following states:
  *   - s_idle:       wait for a request; a zero divisor jumps straight to s_finish
  *   - s_lzd:        count the leading zeros of abs(a) and abs(b)
  *   - s_normlize:   left-shift dividend/divisor and load the iteration counter
  *   - s_recurrence: one radix-4 quotient digit per cycle
  *   - s_recovery:   correct a negative final remainder and de-normalize it
  *   - s_finish:     hold the result until the consumer takes it
  *
  * `io`, `sign` and `ctrl` are not defined in this class; they are presumably
  * provided by [[AbstractDivider]] -- confirm against that class.
  *
  * @param len operand width in bits
  */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery :: s_finish :: Nil = Enum(6)
  val state = RegInit(s_idle)
  // A new request is accepted only while idle.
  val newReq = (state === s_idle) && io.in.fire()
  // Remaining recurrence iterations; loaded in s_normlize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len+3)/2).W))
  val cnt = RegEnable(cnt_next, state===s_normlize || state===s_recurrence)
  // True in the cycle that produces the last quotient digit.
  val rec_enough = cnt_next === 0.U

  /** Split a value into (isNegative, magnitude).
    *
    * @param a    raw operand
    * @param sign whether the operand is to be interpreted as signed
    * @return     (a is negative under `sign`, two's-complement negated `a` if negative else `a`)
    */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }
  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))
  val uop = io.in.bits.uop
  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // Per-request information is latched when the request is accepted.
  val aSignReg = RegEnable(aSign, newReq)           // the remainder takes the dividend's sign
  val qSignReg = RegEnable(aSign ^ bSign, newReq)   // the quotient is negative iff signs differ
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  // Only an in-flight division may be killed. While idle, uopReg still holds the
  // previous (already finished) uop; a redirect that happens to match that stale
  // entry must not cancel the request being accepted in the same cycle, otherwise
  // the `when(kill)` below would override the s_idle transition and the request
  // would be consumed but never executed.
  val kill = state =/= s_idle && uopReg.roqIdx.needFlush(io.redirectIn)

  switch(state){
    is(s_idle){
      // Division by zero needs no iteration: the result is formed directly in s_finish.
      when(io.in.fire()){ state := Mux(divZero, s_finish, s_lzd) }
    }
    is(s_lzd){ // leading zero detection
      state := s_normlize
    }
    is(s_normlize){ // shift a/b
      state := s_recurrence
    }
    is(s_recurrence){ // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough){ state := s_recovery }
    }
    is(s_recovery){ // if rem < 0, rem = rem + d
      state := s_finish
    }
    is(s_finish){
      when(io.out.fire()){ state := s_idle }
    }
  }
  // Last connection wins: a flush of the in-flight uop aborts from any busy state.
  when(kill){
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    *   in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    *   in recovery step, ws/wc = ws[j]/wc[j];
    *   in final step, ws = abs(a)/abs(b).
    *
    * d: normalized divisor(1/2<=d<1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1
  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // During s_lzd, ws and d still hold the raw magnitudes loaded on newReq.
  // Reversing the bit vector makes the priority encoder count from the MSB,
  // i.e. it yields the number of leading zeros.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  // diff = lzc(b) - lzc(a), computed as signed values with one extra zero bit on top.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  // A negative difference is clamped to zero quotient bits.
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount used in s_recovery to undo the divisor normalization on the remainder.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state===s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  // When the dividend has more leading zeros than the divisor, shift it only by
  // the divisor's amount; otherwise normalize it fully.
  a_shifted := Mux(isNegDiff,
    ws(len-1, 0) << bLeadingZeros,
    ws(len-1, 0) << aLeadingZeros
  )
  b_shifted := d(len-1, 0) << bLeadingZeros

  // Resolve the carry-save remainder; a set top bit means it is negative.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen-1), rem_temp + d, rem_temp)
  val rem_abs = (rem_fixed << recoveryShift)(2*len, len+1)

  when(newReq){
    // On divide-by-zero the raw dividend is kept so it can be returned as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize){
    // d layout: 3 zero integer bits, the normalized divisor, one zero LSB.
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    // Align the dividend by one extra bit when the quotient bit count is even.
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence){
    // Pre-multiply by 4 for the next iteration, except after the last digit.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery){
    ws := rem_abs
  }

  // Two quotient bits are produced per iteration.
  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * the quotient selection table use truncated 7-bit remainder
    * and 3-bit divisor
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Completes the two's-complement negation: -d = ~d + 1 and -2d = (~d << 1) + 2.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  // Signed estimate of the remainder from the top 7 bits of each carry-save half.
  val w_truncated = (ws(wLen-1, wLen-1-6) + wc(wLen-1, wLen-1-6)).asSInt()
  // The three divisor bits below bit `len`; they index the rows of qSelTable.
  val d_truncated = d(len-1, len-3)

  // One row per d_truncated value. Each row holds the thresholds on w_truncated
  // for selecting, in order: +2d, +d, 0, -d (anything below the last selects -2d).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // Built once per distinct threshold so identical comparators are shared between rows.
  var ge = Map[Int, Bool]()
  for(row <- qSelTable){
    for(k <- row){
      if(!ge.contains(k)) ge = ge + (k -> (w_truncated >= k.S(7.W)))
    }
  }
  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      // MuxCase takes the first true condition, so thresholds are tested from high to low.
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({case(v, i) => i.U -> v})
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  // The two low bits of wc are replaced by the negation adjustment.
  csa.io.in(1) := Cat(wc(wLen-1, 2), wc_adj)
  // A positive digit subtracts the divisor multiple, a negative digit adds it.
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // q accumulates the quotient and qm accumulates (q - 1); each iteration appends
  // two bits to one of them, so negative digits never need a borrow to propagate.
  val q, qm = Reg(UInt(len.W))
  when(newReq){
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence){
    // digit -> (register to extend, two bits to append)
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery){
    // A negative final remainder means the quotient is one too large: use q - 1.
    q := Mux(rem_temp(wLen-1), qm, q)
  }


  // Re-apply the signs that were stripped by abs().
  val remainder = Mux(aSignReg, -ws(len-1, 0), ws(len-1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // isHi selects the remainder, otherwise the quotient. On divide-by-zero the
  // remainder is the dividend (kept untouched in ws) and the quotient is all ones.
  val res = Mux(ctrlReg.isHi,
    Mux(divZeroReg, ws(len-1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )

  io.in.ready := state===s_idle
  // A result being flushed in this very cycle must not be presented as valid.
  io.out.valid := state===s_finish && !kill
  // isW: the result is the low 32 bits sign-extended to len.
  io.out.bits.data := Mux(ctrlReg.isW,
    SignExt(res(31, 0), len),
    res
  )
  io.out.bits.uop := uopReg

}
231