// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala (revision 69b52b93fd0164ab9b16aec21eaf16eeb8fa614d)
1package xiangshan.backend.fu
2
3import chisel3._
4import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
5import chisel3.util._
6import utils.SignExt
7import xiangshan.XSModule
8import xiangshan.backend.fu.util.CSA3_2
9
/** Data path and control FSM of a radix-4 SRT integer divider.
  *
  * Produces the quotient of `src1 / src2`, or the remainder when `isHi` is set.
  * Operands are treated as signed when `sign` is set. The operand magnitudes are
  * divided with a radix-4 SRT recurrence. The partial remainder is kept in
  * carry-save form and the quotient is built by on-the-fly conversion. Signs are
  * applied when the result is formed.
  *
  * Timing: a request is accepted in `s_idle`. A division by zero goes straight to
  * `s_finish`. Otherwise the FSM spends one cycle each in `s_lzd`, `s_normlize`,
  * `s_recovery_1` and `s_recovery_2`, plus `(quotientBits + 3) / 2` cycles in
  * `s_recurrence` (two quotient bits per cycle). It then holds the result in
  * `s_finish` until `out_ready`.
  *
  * @param len operand and result width in bits
  */
class SRT4DividerDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    val src1, src2 = Input(UInt(len.W)) // dividend, divisor
    // valid:  a request is offered this cycle
    // sign:   treat src1/src2 as signed
    // kill_w: the offered request is dropped (the FSM stays in s_idle)
    // kill_r: abort the division in flight and return to s_idle
    // isHi:   return the remainder instead of the quotient
    // isW:    return the low 32 result bits sign-extended to len bits
    // isHi/isW are not registered here; they are read combinationally when the
    // result is formed.
    val valid, sign, kill_w, kill_r, isHi, isW = Input(Bool())
    val in_ready = Output(Bool())  // high only in s_idle
    val out_valid = Output(Bool()) // high only in s_finish
    val out_data = Output(UInt(len.W))
    val out_ready = Input(Bool())
  })

  val (a, b, sign, valid, kill_w, kill_r, isHi, isW) =
    (io.src1, io.src2, io.sign, io.valid, io.kill_w, io.kill_r, io.isHi, io.isW)
  val in_fire = valid && io.in_ready
  val out_fire = io.out_ready && io.out_valid

  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_finish :: Nil = Enum(7)
  val state = RegInit(s_idle)
  // Remaining recurrence iterations: loaded in s_normlize and decremented in every
  // s_recurrence cycle. rec_enough marks the last iteration.
  val cnt_next = Wire(UInt(log2Up((len + 3) / 2).W))
  val cnt = RegEnable(cnt_next, state === s_normlize || state === s_recurrence)
  val rec_enough = cnt_next === 0.U
  val newReq = in_fire

  /** Splits `a` into (isNegative, magnitude). `a` is treated as signed only when `sign` is set. */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  val aSignReg = RegEnable(aSign, newReq)         // sign applied to the remainder
  val qSignReg = RegEnable(aSign ^ bSign, newReq) // sign applied to the quotient
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  switch(state) {
    is(s_idle) {
      when(in_fire && !kill_w) {
        // Division by zero needs no iteration: its result is formed directly.
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd) { // leading zero detection
      state := s_normlize
    }
    is(s_normlize) { // shift a/b
      state := s_recurrence
    }
    is(s_recurrence) { // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough) {
        state := s_recovery_1
      }
    }
    is(s_recovery_1) { // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2) { // recovery shift
      state := s_finish
    }
    is(s_finish) {
      when(out_fire) {
        state := s_idle
      }
    }
  }
  // Must stay after the switch: by last-connect semantics, a flush overrides any
  // transition selected above.
  when(kill_r) {
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form.
    *   - On a new request, ws holds abs(a) (or the raw a on division by zero)
    *     and d holds abs(b).
    *   - In recurrence steps, ws/wc = 4ws[j]/4wc[j].
    *   - In the recovery step, ws/wc = ws[j]/wc[j].
    *   - In the final step, ws holds the remainder magnitude abs(a) % abs(b).
    *     The quotient magnitude is in q.
    *
    * d: normalized divisor (1/2 <= d < 1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1

  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // In s_lzd, ws/d still hold the operand magnitudes loaded on newReq.
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  // diff = lz(b) - lz(a). A negative diff means |a| < |b|, so no quotient bits are needed.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state === s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  // b is shifted by its own leading-zero count so its leading one reaches the MSB.
  // a is shifted by lz(a), or by lz(b) when |a| < |b|.
  a_shifted := Mux(isNegDiff,
    ws(len - 1, 0) << bLeadingZeros,
    ws(len - 1, 0) << aLeadingZeros
  )
  b_shifted := d(len - 1, 0) << bLeadingZeros

  val rem_temp = ws + wc
  // A negative final remainder is corrected by adding d back.
  // In the same cycle, q is replaced by qm (see below).
  val rem_fixed = Mux(rem_temp(wLen - 1), rem_temp + d, rem_temp)
  // Undo the normalization shift and extract the len-bit remainder magnitude.
  val rem_abs = (ws << recoveryShift) (2 * len, len + 1)

  when(newReq) {
    // On division by zero, ws keeps the raw dividend; it is returned as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize) {
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    // When quotientBits is even, a is shifted one extra bit. This matches the
    // two-bits-per-cycle recurrence.
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence) {
    // Store 4*w[j+1] for the next step; the last step keeps w[j+1] unscaled.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1) {
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2) {
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * The digit q(j+1) in {2, 1, 0, -1, -2} is chosen from a 7-bit truncated
    * estimate of the carry-save remainder and the 3 divisor bits just below
    * the divisor's leading one.
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  // -d = ~d + 1 and -2d = (~d << 1) + 2. The +1/+2 is injected through wc_adj below.
  neg_d := ~d
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen - 1, wLen - 1 - 6) + wc(wLen - 1, wLen - 1 - 6)).asSInt()
  val d_truncated = d(len - 1, len - 3)

  // Row i is selected by d_truncated === i. Its entries are the lower bounds of
  // w_truncated for choosing the digits 2, 1, 0 and -1; below the last bound the
  // digit is -2.
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(k): w_truncated >= k. One comparator per distinct threshold, in table order.
  val ge: Map[Int, Bool] =
    qSelTable.flatten.distinct.map(k => k -> (w_truncated >= k.S(7.W))).toMap

  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({ case (v, i) => i.U -> v })
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2] carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  // During the recurrence, the two low bits of wc are zero: wc starts at 0 and is
  // stored as wc_next << 2. Those bits therefore carry the +1/+2 correction for
  // neg_d/neg_dx2.
  csa.io.in(1) := Cat(wc(wLen - 1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On-the-fly quotient conversion.
  // q holds the digits retired so far; qm holds q - 1 in units of the last digit.
  // Each step sets q' = 4*q + digit or 4*qm + (4 + digit), with no carry propagation.
  val q, qm = Reg(UInt(len.W))
  when(newReq) {
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence) {
    // Each entry: digit -> (register shifted left by 2, 2-bit value appended).
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1) {
    // A negative final remainder means the quotient is one too large: take qm.
    q := Mux(rem_temp(wLen - 1), qm, q)
  }


  val remainder = Mux(aSignReg, -ws(len - 1, 0), ws(len - 1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Division by zero: the quotient is all ones and the remainder is the
  // unmodified dividend still held in ws.
  val res = Mux(isHi,
    Mux(divZeroReg, ws(len - 1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )
  io.out_data := Mux(isW,
    SignExt(res(31, 0), len),
    res
  )
  io.in_ready := state === s_idle
  io.out_valid := state === s_finish
}
250
/** Function-unit wrapper around [[SRT4DividerDataModule]].
  *
  * Captures the uop and control bits of the accepted request. Cancels the
  * division when `roqIdx.needFlush` reports a redirect or flush, either for the
  * request being offered or for the one already in flight.
  */
class SRT4Divider(len: Int) extends AbstractDivider(len) {

  val newReq = io.in.fire()

  val uop = io.in.bits.uop
  // Bookkeeping for the division in flight, latched when a request is accepted.
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)

  val divDataModule = Module(new SRT4DividerDataModule(len))

  // Flush of the request offered this cycle (write side).
  val kill_w = uop.roqIdx.needFlush(io.redirectIn, io.flushIn)
  // Flush of the already-accepted request, only while the data module is busy.
  val kill_r = !divDataModule.io.in_ready && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  // ---- request side: operands are taken straight from the input port ----
  divDataModule.io.valid := io.in.valid
  divDataModule.io.src1 := io.in.bits.src(0)
  divDataModule.io.src2 := io.in.bits.src(1)
  divDataModule.io.sign := sign
  io.in.ready := divDataModule.io.in_ready

  // ---- cancellation ----
  divDataModule.io.kill_w := kill_w
  divDataModule.io.kill_r := kill_r

  // ---- response side: result selection uses the latched control bits ----
  divDataModule.io.isHi := ctrlReg.isHi
  divDataModule.io.isW := ctrlReg.isW
  divDataModule.io.out_ready := io.out.ready
  io.out.valid := divDataModule.io.out_valid
  io.out.bits.data := divDataModule.io.out_data
  io.out.bits.uop := uopReg
}
279