// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala (revision 2225d46ebbe2fd16b9b29963c27a7d0385a42709)
package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
import chisel3.util._
import utils.SignExt
import xiangshan.XSModule
import xiangshan.backend.fu.util.CSA3_2

/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * Interface summary:
  *  - `src1` / `src2`: dividend / divisor, captured when `valid && in_ready`.
  *  - `sign`: when set, operands whose MSB is 1 are negated before the
  *    unsigned recurrence and the result signs are restored at the end.
  *  - `kill_w`: the request accepted in this cycle must not start (the FSM
  *    stays in `s_idle`).
  *  - `kill_r`: abort the request in flight (the FSM is forced to `s_idle`).
  *  - `isHi`: output the remainder instead of the quotient.
  *  - `isW`: output the low 32 bits of the result, sign-extended to `len`.
  *  - `isHi` / `isW` are read combinationally on the output side, so the
  *    caller has to keep them stable until the result is taken.
  *
  * Division by zero skips the recurrence: the quotient is all ones and the
  * remainder is the unmodified dividend.
  *
  * @param len operand width in bits; must be at least 32 because the `isW`
  *            path extracts bits 31..0 of the result.
  */
class SRT4DividerDataModule(len: Int) extends Module {
  // Fail early with a readable message instead of an out-of-range bit extract.
  require(len >= 32, s"SRT4DividerDataModule: len must be >= 32 (isW uses res(31, 0)), got $len")

  val io = IO(new Bundle() {
    val src1, src2 = Input(UInt(len.W))
    val valid, sign, kill_w, kill_r, isHi, isW = Input(Bool())
    val in_ready = Output(Bool())
    val out_valid = Output(Bool())
    val out_data = Output(UInt(len.W))
    val out_ready = Input(Bool())
  })

  val (a, b, sign, valid, kill_w, kill_r, isHi, isW) =
    (io.src1, io.src2, io.sign, io.valid, io.kill_w, io.kill_r, io.isHi, io.isW)
  val in_fire = valid && io.in_ready
  val out_fire = io.out_ready && io.out_valid

  // s_pad_* is not used; the two padding states only push s_finish to the
  // encoding 8 so that bit 3 of `state` alone identifies it.
  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_pad_1 :: s_pad_2 :: s_finish :: Nil = Enum(9)
  require(s_finish.litValue() == 8)

  val state = RegInit(s_idle)
  val finished = state(3).asBool // state === s_finish

  // Remaining radix-4 iterations; loaded in s_normlize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len + 3) / 2).W))
  val cnt = RegEnable(cnt_next, state === s_normlize || state === s_recurrence)
  val rec_enough = cnt_next === 0.U
  val newReq = in_fire

  /** Returns (isNegative, magnitude) of `a`, interpreting it as signed only when `sign` is set. */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  // The remainder takes the sign of the dividend, the quotient the XOR of both signs.
  val aSignReg = RegEnable(aSign, newReq)
  val qSignReg = RegEnable(aSign ^ bSign, newReq)
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  switch(state) {
    is(s_idle) {
      when(in_fire && !kill_w) {
        // Division by zero needs no iteration: jump straight to the result state.
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd) { // leading zero detection
      state := s_normlize
    }
    is(s_normlize) { // shift a/b
      state := s_recurrence
    }
    is(s_recurrence) { // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough) {
        state := s_recovery_1
      }
    }
    is(s_recovery_1) { // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2) { // recovery shift
      state := s_finish
    }
    is(s_finish) {
      when(out_fire) {
        state := s_idle
      }
    }
  }
  // Placed after the switch so that a kill overrides any transition above.
  when(kill_r) {
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    * in recovery step, ws/wc = ws[j]/wc[j];
    * in final step, ws holds the non-negative remainder
    * (read back through `remainder` below).
    *
    * d: normalized divisor(1/2<=d<1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1

  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // Leading-zero counts of |a| (held in ws) and |b| (held in d), captured in s_lzd.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  // Signed difference of the leading-zero counts; clamped to zero when negative.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state === s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len - 1, 0) << bLeadingZeros,
    ws(len - 1, 0) << aLeadingZeros
  )
  b_shifted := d(len - 1, 0) << bLeadingZeros

  // Resolve the carry-save remainder; a negative value is corrected by adding d once.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen - 1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift) (2 * len, len + 1)

  when(newReq) {
    // On division by zero keep the raw dividend: it is returned as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize) {
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence) {
    // The last iteration keeps the remainder unscaled for the recovery step.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1) {
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2) {
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * the quotient selection table use truncated 7-bit remainder
    * and 3-bit divisor
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Two's-complement correction for the inverted divisor: +1 for -d, +2 for -2d
  // (neg_dx2 is (~d) << 1, i.e. -2d - 2).
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen - 1, wLen - 1 - 6) + wc(wLen - 1, wLen - 1 - 6)).asSInt()
  val d_truncated = d(len - 1, len - 3)

  // One row per value of d_truncated; each row holds the lower bounds of
  // w_truncated for selecting +2, +1, 0 and -1 (anything below selects -2).
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x
  // One comparator per distinct threshold, shared by all table rows.
  val ge: Map[Int, Bool] = qSelTable.flatten.distinct
    .map(k => k -> (w_truncated >= k.S(7.W)))
    .toMap
  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({ case (v, i) => i.U -> v })
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  // The two low bits of wc carry the +1/+2 correction for the negated divisor.
  csa.io.in(1) := Cat(wc(wLen - 1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion
  // q holds the quotient so far and qm holds q - 1; each step appends one
  // radix-4 digit to whichever of the two yields a non-negative digit.
  val q, qm = Reg(UInt(len.W))
  when(newReq) {
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence) {
    // (source register, appended 2-bit digit) for the new q = 4q + digit
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
    // (source register, appended 2-bit digit) for the new qm = 4q + digit - 1
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1) {
    // A negative final remainder means the quotient is one too large.
    q := Mux(rem_temp(wLen - 1), qm, q)
  }


  // Restore the signs recorded when the request was accepted.
  val remainder = Mux(aSignReg, -ws(len - 1, 0), ws(len - 1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Division by zero: remainder = dividend (kept untouched in ws), quotient = all ones.
  val res = Mux(isHi,
    Mux(divZeroReg, ws(len - 1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )
  io.out_data := Mux(isW,
    SignExt(res(31, 0), len),
    res
  )
  io.in_ready := state === s_idle
  io.out_valid := finished // state === s_finish
}
256
/** Wrapper that connects [[SRT4DividerDataModule]] to the divider interface
  * inherited from `AbstractDivider`.
  *
  * The micro-op and control bits of an accepted request are registered here so
  * that they stay stable while the data module iterates; `ctrl` and `sign` are
  * taken from the parent class (not visible in this file -- presumably decoded
  * from the incoming micro-op; confirm against `AbstractDivider`).
  *
  * @param len operand width in bits, forwarded to the data module
  */
class SRT4Divider(len: Int)(implicit p: Parameters) extends AbstractDivider(len) {

  val newReq = io.in.fire()

  // Latch the request metadata at handshake time; the result is tagged with uopReg.
  val uop = io.in.bits.uop
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)

  val divDataModule = Module(new SRT4DividerDataModule(len))

  // kill_w: the request arriving now is flushed and must not be started.
  val kill_w = uop.roqIdx.needFlush(io.redirectIn, io.flushIn)
  // kill_r: the request in flight (data module not idle) is flushed.
  val kill_r = !divDataModule.io.in_ready && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  divDataModule.io.src1 := io.in.bits.src(0)
  divDataModule.io.src2 := io.in.bits.src(1)
  divDataModule.io.valid := io.in.valid
  divDataModule.io.sign := sign
  divDataModule.io.kill_w := kill_w
  divDataModule.io.kill_r := kill_r
  // Result selection uses the registered control bits, not the live input.
  divDataModule.io.isHi := ctrlReg.isHi
  divDataModule.io.isW := ctrlReg.isW
  divDataModule.io.out_ready := io.out.ready

  io.in.ready := divDataModule.io.in_ready
  io.out.valid := divDataModule.io.out_valid
  io.out.bits.data := divDataModule.io.out_data
  io.out.bits.uop := uopReg
}
285