xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision 3b09ed7645058064c52e40aa630c3f7a50fdd476)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import utils._
7
// Address view used by the banked prediction tables: a virtual address is
// reinterpreted as { tag, idx, offset }.
//   idxBits - width of the index field
//   banks   - number of banks; the low log2Up(banks) bits of idx pick the bank
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // The tag is whatever remains above the index and the 2-bit offset.
  def tagBits: Int = VAddrBits - (idxBits + 2)

  // Keep this declaration order: fromUInt relies on asTypeOf, so the order of
  // the fields determines which address bits land in which field.
  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(2.W)

  // Force x to exactly VAddrBits bits, then view those bits through this bundle.
  def fromUInt(x: UInt) = {
    val addr = x.asTypeOf(UInt(VAddrBits.W))
    addr.asTypeOf(this)
  }

  // Tag field of address x.
  def getTag(x: UInt): UInt = fromUInt(x).tag

  // Full index field of address x (bank select plus in-bank index).
  def getIdx(x: UInt): UInt = fromUInt(x).idx

  // Bank select: the low log2Up(banks) bits of the index.
  def getBank(x: UInt): UInt = {
    val bankBits = log2Up(banks)
    getIdx(x)(bankBits - 1, 0)
  }

  // Index inside the selected bank: the index bits above the bank select.
  def getBankIdx(x: UInt): UInt = {
    val bankBits = log2Up(banks)
    getIdx(x)(idxBits - 1, bankBits)
  }
}
21
// Payload passed from BPU Stage1 to Stage2 (Stage2To3IO extends it unchanged).
// Every leaf is declared as Output; the receiving stages wrap the bundle in
// Flipped(Decoupled(...)).
class Stage1To2IO extends XSBundle {
  // pc the prediction information below belongs to
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    // FetchWidth bits wide -- presumably one hit bit per fetch slot (TODO confirm
    // once Stage1 actually drives it)
    val hits = Output(UInt(FetchWidth.W))
    // FetchWidth targets, each VAddrBits wide. The width has to be built with
    // `.W`: `.B` yields a Bool literal, which is not a Width.
    val targets = Output(Vec(FetchWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    // FetchWidth bits wide -- presumably marks the slot that hit in JBTAC (TODO confirm)
    val hitIdx = Output(UInt(FetchWidth.W))
    // single VAddrBits-wide target shared by the whole fetch packet
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    // FetchWidth bits wide -- presumably one hit bit per fetch slot (TODO confirm)
    val hits = Output(UInt(FetchWidth.W))
    // one taken/not-taken flag per fetch slot
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  // one HistoryLength-bit history value per fetch slot
  val hist = Output(Vec(FetchWidth, UInt(HistoryLength.W)))
}
38
// First stage of the BPU pipeline.
// At the moment this is a placeholder: it accepts every pc and never raises
// either of its valid outputs. `redirect` and `flush` are declared but not
// read yet.
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    // pc to predict for
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirect = Flipped(ValidIO(new Redirect))
    // from Stage3
    val flush = Input(Bool())
    // to ifu, quick prediction result
    val btbOut = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // TODO: delete this!!!
  // Placeholder drives so that every output has a defined source.

  // quick prediction towards the ifu: never valid, payload unconstrained
  io.btbOut.bits := DontCare
  io.btbOut.valid := false.B

  // hand-off towards Stage2: never valid, payload unconstrained
  io.out.bits := DontCare
  io.out.valid := false.B

  // always willing to take a new pc
  io.in.pc.ready := true.B

}
60
// Payload passed from Stage2 to Stage3. It adds nothing to Stage1To2IO yet; the
// separate type only gives the Stage2 -> Stage3 interface its own name.
class Stage2To3IO extends Stage1To2IO
63
// Second stage of the BPU pipeline.
// It currently performs no prediction work: the Stage1 payload is registered
// and presented to Stage3 one cycle later.
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or backend redirects
  // (started by io.flush, stopped by the next io.in.valid; exact set/clear
  // behaviour is defined by utils.BoolStopWatch)
  val flushS2 = BoolStopWatch(io.flush, io.in.valid, startHighPriority = true)

  // do nothing: just pass the payload through a register
  // NOTE(review): both the payload register and the delayed valid are keyed on
  // io.in.valid rather than io.in.fire(), so they update even while io.in.ready
  // is low -- confirm this is intended once Stage1 produces real data.
  io.out.bits := RegEnable(io.in.bits, io.in.valid)
  io.out.valid := RegNext(io.in.valid) && !flushS2

  // Upstream is stalled only while an output is pending and not being taken.
  io.in.ready := !(io.out.valid && !io.out.fire())
}
80
// Third stage of the BPU pipeline.
// It latches the Stage2 payload and reports a valid output once predecode
// information is available. The output payload is still DontCare and flushBPU
// is tied low; `redirect` is declared but not read yet.
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val predecode = Flipped(ValidIO(new Predecode))
    val out = ValidIO(new BranchPrediction)
    // from backend
    val redirect = Flipped(ValidIO(new Redirect)) // only need isCall here
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
  })

  // TODO: delete this!!!
  // io.in.ready := false.B
  // io.out.valid := false.B
  io.out.bits := DontCare
  io.flushBPU := false.B

  // Set by io.flush and cleared by the next io.in.valid (see utils.BoolStopWatch);
  // masks the output while set.
  val flushS3 = BoolStopWatch(io.flush, io.in.valid, startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  when (io.in.fire()) {
    inLatch := io.in.bits
    // The flush input lives on this module's io, not on the Decoupled `in`
    // port (which only has ready/valid/bits): an entry accepted in the same
    // cycle as a flush is latched as invalid.
    validLatch := !io.flush
  }
  // NOTE(review): validLatch is only rewritten on io.in.fire(), so it stays set
  // after io.out.valid if no new input arrives -- confirm whether it should be
  // cleared once the output has been produced.
  io.out.valid := validLatch && io.predecode.valid && !flushS3
  // Accept a new entry when nothing is latched or the latched one leaves this cycle.
  io.in.ready := !validLatch || io.out.valid

}
110
// Top level of the branch prediction unit: instantiates the three pipeline
// stages and wires them to the fetch pc, the backend redirect and the
// predecode information. The previous single-module BTB/JBTAC implementation
// is kept below, commented out, until it is moved into Stage1.
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // flush pipeline and update bpu based on redirect signals from brq
    val redirect = Flipped(ValidIO(new Redirect))
    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }
    // val predMask = Output(Vec(FetchWidth, Bool()))
    // val predTargets = Output(Vec(FetchWidth, UInt(VAddrBits.W)))
    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = ValidIO(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // ---- data path: pc -> s1 -> s2 -> s3 ----
  // io.in.pc is a Valid (no back-pressure), so s1's ready is not consumed here.
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  s1.io.out <> s2.io.in
  s2.io.out <> s3.io.in
  s3.io.predecode <> io.predecode

  // ---- prediction results ----
  io.btbOut <> s1.io.btbOut
  io.tageOut <> s3.io.out

  // ---- backend redirect, forwarded to the stages that take it ----
  s1.io.redirect <> io.redirect
  s3.io.redirect <> io.redirect

  // ---- flush ----
  // Stage1 and Stage2 are flushed by Stage3 or by a backend redirect;
  // Stage3 itself is flushed by a backend redirect only.
  s1.io.flush := s3.io.flushBPU || io.redirect.valid
  s2.io.flush := s3.io.flushBPU || io.redirect.valid
  s3.io.flush := io.redirect.valid

  // TODO: delete this and put BTB and JBTAC into Stage1
  // NOTE(review): before reviving the code below, check the address arithmetic.
  // In Scala `+` and `-` bind tighter than `<<`, so `a + i.U << 2` parses as
  // `(a + i.U) << 2` and `pc + FetchWidth.U << 2.U - 4.U` parses as
  // `(pc + FetchWidth.U) << (2.U - 4.U)`.
  /*
  val flush = BoolStopWatch(io.redirect.valid, io.in.pc.valid, startHighPriority = true)

  // BTB makes a quick prediction for branch and direct jump, which is
  // 4-way set-associative, and each way is divided into 4 banks.
  val btbAddr = new TableAddr(log2Up(BtbSets), BtbBanks)
  def btbEntry() = new Bundle {
    val valid = Bool()
    // TODO: don't need full length of tag and target
    val tag = UInt(btbAddr.tagBits.W)
    val _type = UInt(2.W)
    val target = UInt(VAddrBits.W)
    val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor
  }

  val btb = List.fill(BtbBanks)(List.fill(BtbWays)(
    Module(new SRAMTemplate(btbEntry(), set = BtbSets / BtbBanks, shouldReset = true, holdRead = true, singlePort = true))))

  // val fetchPkgAligned = btbAddr.getBank(io.in.pc.bits) === 0.U
  val HeadBank = btbAddr.getBank(io.in.pc.bits)
  val TailBank = btbAddr.getBank(io.in.pc.bits + FetchWidth.U << 2.U - 4.U)
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      btb(b)(w).reset := reset.asBool
      btb(b)(w).io.r.req.valid := io.in.pc.valid && Mux(TailBank > HeadBank, b.U >= HeadBank && b.U <= TailBank, b.U >= TailBank || b.U <= HeadBank)
      btb(b)(w).io.r.req.bits.setIdx := btbAddr.getBankIdx(io.in.pc.bits)
    }
  }
  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.valid)
  val btbRead = Wire(Vec(BtbBanks, Vec(BtbWays, btbEntry())))
  val btbHits = Wire(Vec(FetchWidth, Bool()))
  val btbTargets = Wire(Vec(FetchWidth, UInt(VAddrBits.W)))
  val btbTypes = Wire(Vec(FetchWidth, UInt(2.W)))
  // val btbPreds = Wire(Vec(FetchWidth, UInt(2.W)))
  val btbTakens = Wire(Vec(FetchWidth, Bool()))
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      btbRead(b)(w) := btb(b)(w).io.r.resp.data(0)
    }
  }
  for (i <- 0 until FetchWidth) {
    btbHits(i) := false.B
    for (b <- 0 until BtbBanks) {
      for (w <- 0 until BtbWays) {
        when (b.U === btbAddr.getBank(pcLatch) && btbRead(b)(w).valid && btbRead(b)(w).tag === btbAddr.getTag(Cat(pcLatch(VAddrBits - 1, 2), 0.U(2.W)) + i.U << 2)) {
          btbHits(i) := !flush && RegNext(btb(b)(w).io.r.req.fire(), init = false.B)
          btbTargets(i) := btbRead(b)(w).target
          btbTypes(i) := btbRead(b)(w)._type
          // btbPreds(i) := btbRead(b)(w).pred
          btbTakens(i) := (btbRead(b)(w).pred)(1).asBool
        }.otherwise {
          btbHits(i) := false.B
          btbTargets(i) := DontCare
          btbTypes(i) := DontCare
          btbTakens(i) := DontCare
        }
      }
    }
  }

  // JBTAC, divided into 8 banks, makes prediction for indirect jump except ret.
  val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks)
  def jbtacEntry() = new Bundle {
    val valid = Bool()
    // TODO: don't need full length of tag and target
    val tag = UInt(jbtacAddr.tagBits.W)
    val target = UInt(VAddrBits.W)
  }

  val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = true)))

  (0 until JbtacBanks).map(i => jbtac(i).reset := reset.asBool)
  (0 until JbtacBanks).map(i => jbtac(i).io.r.req.valid := io.in.pc.valid)
  (0 until JbtacBanks).map(i => jbtac(i).io.r.req.bits.setIdx := jbtacAddr.getBankIdx(Cat((io.in.pc.bits)(VAddrBits - 1, 2), 0.U(2.W)) + i.U << 2))

  val jbtacRead = Wire(Vec(JbtacBanks, jbtacEntry()))
  (0 until JbtacBanks).map(i => jbtacRead(i) := jbtac(i).io.r.resp.data(0))
  val jbtacHits = Wire(Vec(FetchWidth, Bool()))
  val jbtacTargets = Wire(Vec(FetchWidth, UInt(VAddrBits.W)))
  val jbtacHeadBank = jbtacAddr.getBank(Cat(pcLatch(VAddrBits - 1, 2), 0.U(2.W)))
  for (i <- 0 until FetchWidth) {
    jbtacHits(i) := false.B
    for (b <- 0 until JbtacBanks) {
      when (jbtacHeadBank + i.U === b.U) {
        jbtacHits(i) := jbtacRead(b).valid && jbtacRead(b).tag === jbtacAddr.getTag(Cat(pcLatch(VAddrBits - 1, 2), 0.U(2.W)) + i.U << 2) &&
          !flush && RegNext(jbtac(b).io.r.req.fire(), init = false.B)
        jbtacTargets(i) := jbtacRead(b).target
      }.otherwise {
        jbtacHits(i) := false.B
        jbtacTargets(i) := DontCare
      }
    }
  }

  // redirect based on BTB and JBTAC
  (0 until FetchWidth).map(i => io.predMask(i) := btbHits(i) && Mux(btbTypes(i) === BTBtype.B, btbTakens(i), true.B) || jbtacHits(i))
  (0 until FetchWidth).map(i => io.predTargets(i) := Mux(btbHits(i) && !(btbTypes(i) === BTBtype.B && !btbTakens(i)), btbTargets(i), jbtacTargets(i)))


  // update bpu, including BTB, JBTAC...
  // 1. update BTB
  // 1.1 read the selected bank
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      btb(b)(w).io.r.req.valid := io.redirect.valid && btbAddr.getBank(io.redirect.bits.pc) === b.U
      btb(b)(w).io.r.req.bits.setIdx := btbAddr.getBankIdx(io.redirect.bits.pc)
    }
  }

  // 1.2 match redirect pc tag with the 4 tags in a btb line, find a way to write
  // val redirectLatch = RegEnable(io.redirect.bits, io.redirect.valid)
  val redirectLatch = RegNext(io.redirect.bits, init = 0.U.asTypeOf(new Redirect))
  val bankLatch = btbAddr.getBank(redirectLatch.pc)
  val btbUpdateRead = Wire(Vec(BtbWays, btbEntry()))
  val btbValids = Wire(Vec(BtbWays, Bool()))
  val btbUpdateTagHits = Wire(Vec(BtbWays, Bool()))
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      when (b.U === bankLatch) {
        btbUpdateRead(w) := btb(b)(w).io.r.resp.data(0)
        btbValids(w) := btbUpdateRead(w).valid && RegNext(btb(b)(w).io.r.req.fire(), init = false.B)
      }.otherwise {
        btbUpdateRead(w) := 0.U.asTypeOf(btbEntry())
        btbValids(w) := false.B
      }
    }
  }
  (0 until BtbWays).map(w => btbUpdateTagHits(w) := btbValids(w) && btbUpdateRead(w).tag === btbAddr.getTag(redirectLatch.pc))
  // val btbWriteWay = Wire(Vec(BtbWays, Bool()))
  val btbWriteWay = Wire(UInt(BtbWays.W))
  val btbInvalids = ~ btbValids.asUInt
  when (btbUpdateTagHits.asUInt.orR) {
    // tag hits
    btbWriteWay := btbUpdateTagHits.asUInt
  }.elsewhen (!btbValids.asUInt.andR) {
    // no tag hits but there are free entries
    btbWriteWay := Mux(btbInvalids >= 8.U, "b1000".U,
      Mux(btbInvalids >= 4.U, "b0100".U,
      Mux(btbInvalids >= 2.U, "b0010".U, "b0001".U)))
  }.otherwise {
    // no tag hits and no free entry, select a victim way
    btbWriteWay := UIntToOH(LFSR64()(log2Up(BtbWays) - 1, 0))
  }

  // 1.3 calculate new 2-bit counter value
  val btbWrite = WireInit(0.U.asTypeOf(btbEntry()))
  btbWrite.valid := true.B
  btbWrite.tag := btbAddr.getTag(redirectLatch.pc)
  btbWrite._type := redirectLatch._type
  btbWrite.target := redirectLatch.brTarget
  val oldPred = WireInit("b01".U)
  oldPred := PriorityMux(btbWriteWay.asTypeOf(Vec(BtbWays, Bool())), btbUpdateRead.map{ e => e.pred })
  val newPred = Mux(redirectLatch.taken, Mux(oldPred === "b11".U, "b11".U, oldPred + 1.U),
    Mux(oldPred === "b00".U, "b00".U, oldPred - 1.U))
  btbWrite.pred := Mux(btbUpdateTagHits.asUInt.orR && redirectLatch._type === BTBtype.B, newPred, "b01".U)

  // 1.4 write BTB
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      when (b.U === bankLatch) {
        btb(b)(w).io.w.req.valid := OHToUInt(btbWriteWay) === w.U &&
          RegNext(io.redirect.valid, init = false.B) &&
          (redirectLatch._type === BTBtype.B || redirectLatch._type === BTBtype.J)
        btb(b)(w).io.w.req.bits.setIdx := btbAddr.getBankIdx(redirectLatch.pc)
        btb(b)(w).io.w.req.bits.data := btbWrite
      }.otherwise {
        btb(b)(w).io.w.req.valid := false.B
        btb(b)(w).io.w.req.bits.setIdx := DontCare
        btb(b)(w).io.w.req.bits.data := DontCare
      }
    }
  }

  // 2. update JBTAC
  val jbtacWrite = WireInit(0.U.asTypeOf(jbtacEntry()))
  jbtacWrite.valid := true.B
  jbtacWrite.tag := jbtacAddr.getTag(io.redirect.bits.pc)
  jbtacWrite.target := io.redirect.bits.target
  (0 until JbtacBanks).map(b =>
    jbtac(b).io.w.req.valid := io.redirect.valid &&
      b.U === jbtacAddr.getBank(io.redirect.bits.pc) &&
      io.redirect.bits._type === BTBtype.I)
  (0 until JbtacBanks).map(b => jbtac(b).io.w.req.bits.setIdx := jbtacAddr.getBankIdx(io.redirect.bits.pc))
  (0 until JbtacBanks).map(b => jbtac(b).io.w.req.bits.data := jbtacWrite)
  */
}
333