xref: /XiangShan/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala (revision d8a66f7ecacc2b3538a1083b0716b7066c77ba3f)
1package xiangshan.backend.fu
2
3import chisel3._
4import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
5import chisel3.util._
6import utils.SignExt
7import xiangshan.backend.fu.util.CSA3_2
8
/** A radix-4 SRT integer divider.
  *
  * State sequence for one request:
  * {{{
  *   s_idle -> s_lzd -> s_normlize -> s_recurrence (repeated) -> s_recovery_1 -> s_recovery_2 -> s_finish
  * }}}
  * A request whose divisor is zero goes straight from s_idle to s_finish.
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * NOTE(review): `io`, `sign` and `ctrl` are not defined in this class; they are
  * presumably inherited from [[AbstractDivider]] -- confirm against that class.
  *
  * @param len operand width in bits
  */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_finish :: Nil = Enum(7)
  val state = RegInit(s_idle)
  // A new request is taken only while the divider is idle.
  val newReq = (state === s_idle) && io.in.fire()
  // Iteration counter: loaded in s_normlize, decremented on every s_recurrence cycle.
  val cnt_next = Wire(UInt(log2Up((len+3)/2).W))
  val cnt = RegEnable(cnt_next, state===s_normlize || state===s_recurrence)
  // True when the counter value being written this cycle is zero.
  val rec_enough = cnt_next === 0.U

  /** Returns (isNegative, magnitude) of `a`; `a` is only treated as negative when `sign` is set. */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }
  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))
  val uop = io.in.bits.uop
  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // Per-request information latched when the request is accepted.
  val aSignReg = RegEnable(aSign, newReq)           // sign applied to the remainder
  val qSignReg = RegEnable(aSign ^ bSign, newReq)   // sign applied to the quotient
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  // Abort the in-flight division when needFlush fires for the latched uop.
  val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  switch(state){
    is(s_idle){
      // Do not start a request that is already being flushed.
      when (io.in.fire() && !io.in.bits.uop.roqIdx.needFlush(io.redirectIn, io.flushIn)) {
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd){ // leading zero detection
      state := s_normlize
    }
    is(s_normlize){ // shift a/b
      state := s_recurrence
    }
    is(s_recurrence){ // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough){ state := s_recovery_1 }
    }
    is(s_recovery_1){ // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2){ // recovery shift
      state := s_finish
    }
    is(s_finish){
      when(io.out.fire()){ state := s_idle }
    }
  }
  // Must stay after the switch: the last connection wins, so kill overrides it.
  when(kill){
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    *   in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    *   in recovery step, ws/wc = ws[j]/wc[j];
    *   in final step, ws = abs(a)/abs(b).
    *
    * d: normalized divisor (1/2 <= d < 1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1
  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // Leading-zero counts of the operand magnitudes held in ws / d, captured in s_lzd.
  // The bits are scanned MSB first, so the encoder output is the index of the
  // first set bit counted from the top.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len-1, 0).asBools().reverse),
    enable = state===s_lzd
  )
  // Signed difference of the leading-zero counts; a negative value is clamped to 0.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount used in s_recovery_2 to extract the remainder.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state===s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  // When diff is negative the dividend is shifted by the divisor's count instead of its own.
  a_shifted := Mux(isNegDiff,
    ws(len-1, 0) << bLeadingZeros,
    ws(len-1, 0) << aLeadingZeros
  )
  b_shifted := d(len-1, 0) << bLeadingZeros

  // Collapse the carry-save remainder, then add d back if it is negative.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen-1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift)(2*len, len+1)

  when(newReq){
    // For a zero divisor ws keeps the raw dividend, which is what the remainder output returns.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize){
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence){
    // Pre-multiply by 4 for the next iteration, except on the final one.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1){
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2){
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * the quotient selection table uses a truncated 7-bit remainder
    * and a 3-bit divisor
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Deferred increment for the inverted divisor operands: 1 for sel_d, 2 for sel_dx2.
  // It is placed into the two low bits of the carry operand fed to the adder below.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen-1, wLen-1-6) + wc(wLen-1, wLen-1-6)).asSInt()
  val d_truncated = d(len-1, len-3)

  // One row per d_truncated value; each row holds the four descending thresholds
  // that separate the digit choices dx2 / d / 0 / neg_d / neg_dx2.
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // One comparator per distinct threshold, shared by all rows; distinct keeps
  // first-occurrence order, so comparators are created in table order.
  val ge: Map[Int, Bool] = qSelTable.flatten.distinct
    .map(k => k -> (w_truncated >= k.S(7.W)))
    .toMap

  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({case(v, i) => i.U -> v})
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    *
    * NOTE(review): out(0)/out(1) of CSA3_2 are presumably the sum and carry
    * vectors -- confirm against CSA3_2.
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  csa.io.in(1) := Cat(wc(wLen-1, 2), wc_adj)
  // Third operand is the selected multiple of d with the opposite sign of the digit.
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // Each recurrence cycle appends two bits to either q or qm, chosen by the selected digit.
  // NOTE(review): qm presumably tracks q - 1 -- confirm against the conversion algorithm.
  val q, qm = Reg(UInt(len.W))
  when(newReq){
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence){
    // digit -> (register to extend, two bits appended)
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1){
    // A negative final remainder selects qm as the quotient.
    q := Mux(rem_temp(wLen-1), qm, q)
  }


  // Re-apply the signs latched at request time.
  val remainder = Mux(aSignReg, -ws(len-1, 0), ws(len-1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // isHi selects the remainder, otherwise the quotient.
  // For a zero divisor the remainder output is the low len bits of ws and the quotient is all ones.
  val res = Mux(ctrlReg.isHi,
    Mux(divZeroReg, ws(len-1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )

  io.in.ready := state===s_idle
  io.out.valid := state===s_finish
  // isW: the low 32 bits of the result are sign-extended to len bits.
  io.out.bits.data := Mux(ctrlReg.isW,
    SignExt(res(31, 0), len),
    res
  )
  io.out.bits.uop := uopReg

}
239