// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala (revision c6d439803a044ea209139672b25e35fe8d7f4aa0)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
15
16package xiangshan.backend.fu
17
18import chipsalliance.rocketchip.config.Parameters
19import chisel3._
20import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
21import chisel3.util._
22import utils.SignExt
23import xiangshan.XSModule
24import xiangshan.backend.fu.util.CSA3_2
25
/** A Radix-4 SRT Integer Divider
  *
  * 2 ~ (5 + (len+3)/2) cycles are needed for each division.
  *
  * Protocol, as implemented below:
  *  - a request is accepted when `valid && in_ready`; `in_ready` is high only in `s_idle`;
  *  - the result is presented on `out_data` while `out_valid` (state `s_finish`) is high,
  *    and the module returns to `s_idle` once `out_ready` is seen;
  *  - `kill_w` prevents the request offered in the current cycle from starting;
  *  - `kill_r` forces the state machine back to `s_idle` from any state.
  *
  * Result selection: `isHi` selects the remainder (otherwise the quotient); `isW`
  * sign-extends the low 32 bits of the selected result to `len` bits.
  * Division by zero bypasses the recurrence and yields an all-ones quotient and the
  * unmodified dividend as the remainder.
  *
  * @param len width of the operands and of the result, in bits
  */
class SRT4DividerDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(len.W)))
    val valid, sign, kill_w, kill_r, isHi, isW = Input(Bool())
    val in_ready = Output(Bool())
    val out_valid = Output(Bool())
    val out_data = Output(UInt(len.W))
    val out_ready = Input(Bool())
  })

  val (a, b, sign, valid, kill_w, kill_r, isHi, isW) =
    (io.src(0), io.src(1), io.sign, io.valid, io.kill_w, io.kill_r, io.isHi, io.isW)
  val in_fire = valid && io.in_ready
  val out_fire = io.out_ready && io.out_valid

  // s_pad_* are never entered: they only pad the encoding so that s_finish == 8,
  // which lets `finished` be decoded from a single state bit.
  val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery_1 :: s_recovery_2 :: s_pad_1 :: s_pad_2 :: s_finish :: Nil = Enum(9)
  require(s_finish.litValue() == 8)

  val state = RegInit(s_idle)
  val finished = state(3).asBool // state === s_finish (the only state with bit 3 set)

  // Remaining radix-4 iterations: loaded in s_normlize, decremented in s_recurrence.
  val cnt_next = Wire(UInt(log2Up((len + 3) / 2).W))
  val cnt = RegEnable(cnt_next, state === s_normlize || state === s_recurrence)
  val rec_enough = cnt_next === 0.U
  val newReq = in_fire

  /** Splits `a` into (isNegative, magnitude); `a` is treated as signed only when `sign` is set. */
  def abs(a: UInt, sign: Bool): (Bool, UInt) = {
    val s = a(len - 1) && sign
    (s, Mux(s, -a, a))
  }

  val (aSign, aVal) = abs(a, sign)
  val (bSign, bVal) = abs(b, sign)
  val aSignReg = RegEnable(aSign, newReq)         // remainder takes the dividend's sign
  val qSignReg = RegEnable(aSign ^ bSign, newReq) // quotient is negative iff signs differ
  val divZero = b === 0.U
  val divZeroReg = RegEnable(divZero, newReq)

  switch(state) {
    is(s_idle) {
      when(in_fire && !kill_w) {
        state := Mux(divZero, s_finish, s_lzd)
      }
    }
    is(s_lzd) { // leading zero detection
      state := s_normlize
    }
    is(s_normlize) { // shift a/b
      state := s_recurrence
    }
    is(s_recurrence) { // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
      when(rec_enough) {
        state := s_recovery_1
      }
    }
    is(s_recovery_1) { // if rem < 0, rem = rem + d
      state := s_recovery_2
    }
    is(s_recovery_2) { // recovery shift
      state := s_finish
    }
    is(s_finish) {
      when(out_fire) {
        state := s_idle
      }
    }
  }
  // Flush: placed after the switch so it overrides any transition above (last connect wins).
  when(kill_r) {
    state := s_idle
  }

  /** Calculate abs(a)/abs(b) by recurrence
    *
    * ws, wc: partial remainder in carry-save form,
    * in recurrence steps, ws/wc = 4ws[j]/4wc[j];
    * in recovery step, ws/wc = ws[j]/wc[j];
    * in final step, ws = abs(a) % abs(b), i.e. the unsigned remainder
    * (the quotient is accumulated separately in q/qm).
    *
    * d: normalized divisor (1/2 <= d < 1)
    *
    * wLen = 3 integer bits + (len+1) frac bits
    */
  def wLen = 3 + len + 1

  val ws, wc = Reg(UInt(wLen.W))
  val ws_next, wc_next = Wire(UInt(wLen.W))
  val d = Reg(UInt(wLen.W))

  // In s_lzd, ws and d still hold the values loaded on newReq (|a| and |b|).
  val aLeadingZeros = RegEnable(
    next = PriorityEncoder(ws(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  val bLeadingZeros = RegEnable(
    next = PriorityEncoder(d(len - 1, 0).asBools().reverse),
    enable = state === s_lzd
  )
  // lz(b) - lz(a), clamped at 0; a negative difference means |a| < |b|.
  val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt()
  val isNegDiff = diff(diff.getWidth - 1)
  val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt())
  val qBitsIsOdd = quotientBits(0)
  // Shift amount applied to the remainder in s_recovery_2.
  val recoveryShift = RegEnable(len.U - bLeadingZeros, state === s_normlize)
  val a_shifted, b_shifted = Wire(UInt(len.W))
  a_shifted := Mux(isNegDiff,
    ws(len - 1, 0) << bLeadingZeros,
    ws(len - 1, 0) << aLeadingZeros
  )
  b_shifted := d(len - 1, 0) << bLeadingZeros

  // Non-redundant remainder; if it is negative, d is added back once.
  val rem_temp = ws + wc
  val rem_fixed = Mux(rem_temp(wLen - 1), rem_temp + d, rem_temp)
  val rem_abs = (ws << recoveryShift) (2 * len, len + 1)

  when(newReq) {
    // On divide-by-zero the raw dividend is kept in ws; it is returned as the remainder.
    ws := Cat(0.U(4.W), Mux(divZero, a, aVal))
    wc := 0.U
    d := Cat(0.U(4.W), bVal)
  }.elsewhen(state === s_normlize) {
    d := Cat(0.U(3.W), b_shifted, 0.U(1.W))
    // One extra left shift of the dividend when quotientBits is even.
    ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1)
  }.elsewhen(state === s_recurrence) {
    // The last iteration stores the residual without the x4 scaling.
    ws := Mux(rec_enough, ws_next, ws_next << 2)
    wc := Mux(rec_enough, wc_next, wc_next << 2)
  }.elsewhen(state === s_recovery_1) {
    ws := rem_fixed
  }.elsewhen(state === s_recovery_2) {
    ws := rem_abs
  }

  cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U)

  /** Quotient selection
    *
    * The quotient digit in {-2, -1, 0, +1, +2} is selected from a truncated 7-bit
    * estimate of the carry-save remainder and the 3 divisor bits that follow the
    * leading one of the normalized divisor.
    */
  val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5)
  val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W))
  dx2 := d << 1
  neg_d := (~d).asUInt() // add '1' in carry-save adder later
  neg_dx2 := neg_d << 1

  val q_sel = Wire(UInt(3.W))
  // Completes the two's-complement negation: ~d needs +1, (~d) << 1 needs +2.
  // Injected by replacing the two low bits of wc at the CSA input.
  val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq(
    sel_d -> 1.U(2.W),
    sel_dx2 -> 2.U(2.W)
  ))

  val w_truncated = (ws(wLen - 1, wLen - 1 - 6) + wc(wLen - 1, wLen - 1 - 6)).asSInt()
  val d_truncated = d(len - 1, len - 3)

  // Row i is used when d_truncated == i. Columns are the thresholds for selecting
  // +2, +1, 0 and -1 (checked in that order); below the last one, -2 is selected.
  val qSelTable = Array(
    Array(12, 4, -4, -13),
    Array(14, 4, -6, -15),
    Array(15, 4, -6, -16),
    Array(16, 4, -6, -18),
    Array(18, 6, -8, -20),
    Array(20, 6, -8, -20),
    Array(20, 8, -8, -22),
    Array(24, 8, -8, -24)
  )

  // ge(x): w_truncated >= x. One comparator per distinct threshold, shared by all rows
  // (built in first-occurrence order).
  val ge: Map[Int, Bool] =
    qSelTable.flatten.distinct.map(k => k -> (w_truncated >= k.S(7.W))).toMap
  q_sel := MuxLookup(d_truncated, sel_0,
    qSelTable.map(x =>
      MuxCase(sel_neg_dx2, Seq(
        ge(x(0)) -> sel_dx2,
        ge(x(1)) -> sel_d,
        ge(x(2)) -> sel_0,
        ge(x(3)) -> sel_neg_d
      ))
    ).zipWithIndex.map({ case (v, i) => i.U -> v })
  )

  /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder
    *
    * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d
    */
  val csa = Module(new CSA3_2(wLen))
  csa.io.in(0) := ws
  csa.io.in(1) := Cat(wc(wLen - 1, 2), wc_adj)
  csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq(
    sel_d -> neg_d,
    sel_dx2 -> neg_dx2,
    sel_neg_d -> d,
    sel_neg_dx2 -> dx2
  ))
  ws_next := csa.io.out(0)
  wc_next := csa.io.out(1) << 1

  // On the fly quotient conversion: each step appends a 2-bit digit to q or qm.
  // Given qm = q - 1, this yields q' = 4q + digit and qm' = q' - 1 without carry propagation.
  val q, qm = Reg(UInt(len.W))
  when(newReq) {
    q := 0.U
    qm := 0.U
  }.elsewhen(state === s_recurrence) {
    val qMap = Seq(
      sel_0 -> (q, 0),
      sel_d -> (q, 1),
      sel_dx2 -> (q, 2),
      sel_neg_d -> (qm, 3),
      sel_neg_dx2 -> (qm, 2)
    )
    q := MuxLookup(q_sel, 0.U,
      qMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
    val qmMap = Seq(
      sel_0 -> (qm, 3),
      sel_d -> (q, 0),
      sel_dx2 -> (q, 1),
      sel_neg_d -> (qm, 2),
      sel_neg_dx2 -> (qm, 1)
    )
    qm := MuxLookup(q_sel, 0.U,
      qmMap.map(m => m._1 -> Cat(m._2._1(len - 3, 0), m._2._2.U(2.W)))
    )
  }.elsewhen(state === s_recovery_1) {
    // A negative final remainder is corrected by +d, so the quotient is qm instead of q.
    q := Mux(rem_temp(wLen - 1), qm, q)
  }

  // Re-apply signs: the remainder follows the dividend, the quotient follows aSign ^ bSign.
  val remainder = Mux(aSignReg, -ws(len - 1, 0), ws(len - 1, 0))
  val quotient = Mux(qSignReg, -q, q)

  // Divide-by-zero: remainder = dividend (kept raw in ws), quotient = all ones.
  val res = Mux(isHi,
    Mux(divZeroReg, ws(len - 1, 0), remainder),
    Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient)
  )
  // W-type operation: sign-extend the low 32 bits of the result.
  io.out_data := Mux(isW,
    SignExt(res(31, 0), len),
    res
  )
  io.in_ready := state === s_idle
  io.out_valid := finished // state === s_finish
}
271
/** Radix-4 SRT divider function unit.
  *
  * Thin wrapper around SRT4DividerDataModule. It latches the micro-op and its
  * control bits when a request is accepted, derives the two kill signals from
  * `roqIdx.needFlush`, and forwards the handshake and result.
  */
class SRT4Divider(len: Int)(implicit p: Parameters) extends AbstractDivider(len) {

  val newReq = io.in.fire()

  val uop = io.in.bits.uop
  val uopReg = RegEnable(uop, newReq)
  val ctrlReg = RegEnable(ctrl, newReq)

  val divDataModule = Module(new SRT4DividerDataModule(len))

  // Flush of the request being offered this cycle.
  val kill_w = uop.roqIdx.needFlush(io.redirectIn, io.flushIn)
  // Flush of the request already accepted (the data module is not idle).
  val kill_r = !divDataModule.io.in_ready && uopReg.roqIdx.needFlush(io.redirectIn, io.flushIn)

  // ---- results / handshake back to the pipeline ----
  io.out.bits.uop := uopReg
  io.out.bits.data := divDataModule.io.out_data
  io.out.valid := divDataModule.io.out_valid
  io.in.ready := divDataModule.io.in_ready

  // ---- drive the data module ----
  divDataModule.io.out_ready := io.out.ready
  divDataModule.io.isW := ctrlReg.isW
  divDataModule.io.isHi := ctrlReg.isHi
  divDataModule.io.kill_r := kill_r
  divDataModule.io.kill_w := kill_w
  divDataModule.io.sign := sign
  divDataModule.io.valid := io.in.valid
  divDataModule.io.src.zipWithIndex.foreach { case (operand, i) =>
    operand := io.in.bits.src(i)
  }
}
300