// Source: XiangShan src/main/scala/xiangshan/frontend/BPU.scala (revision 96a5133979de9ca0b86ba19e9222a2e57b4b871f)
package xiangshan.frontend

import chisel3._
import chisel3.util._

import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.utils._
// Address view used to index the banked prediction tables (BTB / JBTAC).
// A virtual address is split as | tag | idx | offset(2) |; the low
// log2Up(banks) bits of idx select the bank and the remaining idx bits
// select the set inside that bank.
// NOTE(review): the 2-bit offset assumes 4-byte (non-RVC) instruction
// alignment — confirm once RVC support lands.
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // Tag is whatever remains of the VAddr above idx and the 2-bit offset.
  def tagBits = VAddrBits - idxBits - 2

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(2.W)

  // Reinterpret a raw address as this bundle (field order defines the split).
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // Bank = low bits of idx; bank-local set index = high bits of idx.
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}
23
// Payload passed from BPU Stage1 to Stage2: the latched fetch pc, the raw
// BTB / JBTAC / TAGE lookup results, per-instruction global history, and
// Stage1's quick prediction (btbPred) so Stage3 can later compare against it.
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    val hits = Output(UInt(FetchWidth.W))                  // per-slot BTB hit bits
    val targets = Output(Vec(FetchWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(FetchWidth.W))                // one-hot slot of the indirect-jump hit
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))           // per-slot taken/not-taken direction
  }
  val hist = Output(Vec(FetchWidth, UInt(HistoryLength.W)))
  val btbPred = ValidIO(new BranchPrediction)              // Stage1's quick prediction
}
41
// First BPU pipeline stage.
// Looks up the banked BTB and JBTAC SRAMs and the TAGE predictor in one
// cycle, emits a quick prediction to the IFU (s1OutPred), and forwards the
// raw per-structure results to Stage2.
// NOTE(review): this class contained an unresolved merge conflict
// (<<<<<<< / ======= / >>>>>>>); resolved here by keeping the upstream
// (working) version — the stashed side was an earlier, non-compiling draft.
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Flipped(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    // to ifu, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // TODO: delete this!!!
  io.in.pc.ready := true.B

  // Defaults; the real connections near the bottom take effect via Chisel's
  // last-connect semantics.
  io.out.valid := false.B
  io.out.bits := DontCare

  // flush Stage1 when io.flush
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)

  // Global history register; set updateGhr/newGhr to modify it.
  val ghr = RegInit(0.U(HistoryLength.W))
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // hist bypasses the in-flight update so same-cycle readers see the new value.
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor
  val tage = Module(new Tage)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  io.out.bits.tage <> tage.io.out
  io.s1OutPred.bits.tageMeta := tage.io.meta

  // flush Stage1 when io.flush || io.redirect.valid
  val flush = flushS1 || io.redirectInfo.valid

  val btbAddr = new TableAddr(log2Up(BtbSets), BtbBanks)
  val predictWidth = FetchWidth
  def btbTarget = new Bundle {
    val addr = UInt(VAddrBits.W)
    val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor
    val _type = UInt(2.W)
    val offset = UInt(offsetBits().W) // Could be zero

    def offsetBits() = log2Up(FetchWidth / predictWidth)
  }

  def btbEntry() = new Bundle {
    val valid = Bool()
    // TODO: don't need full length of tag and target
    val tag = UInt(btbAddr.tagBits.W)
    val target = Vec(predictWidth, btbTarget)
  }

  val btb = List.fill(BtbWays)(List.fill(BtbBanks)(
    Module(new SRAMTemplate(btbEntry(), set = BtbSets / BtbBanks, shouldReset = true, holdRead = true, singlePort = false))))

  // BTB read requests: read the pc-addressed bank of every way in parallel.
  (0 until BtbWays).map(
    w => (0 until BtbBanks).map(
      b => {
        btb(w)(b).reset := reset.asBool
        btb(w)(b).io.r.req.valid := io.in.pc.valid && b.U === btbAddr.getBank(io.in.pc.bits)
        btb(w)(b).io.r.req.bits.setIdx := btbAddr.getBankIdx(io.in.pc.bits)
      }))

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.valid)
  // Entries read from SRAM
  val btbRead = Wire(Vec(BtbWays, Vec(BtbBanks, btbEntry())))
  // per-way hit flags (at most one way should hit)
  val btbHits = Wire(Vec(BtbWays, Bool()))

  // #(predictWidth) per-slot results
  val btbTargets = Wire(Vec(predictWidth, UInt(VAddrBits.W)))
  val btbTypes = Wire(Vec(predictWidth, UInt(2.W)))
  val btbTakens = Wire(Vec(predictWidth, Bool()))

  val btbHitWay = Wire(UInt(log2Up(BtbWays).W))

  (0 until BtbWays).map(
    w => (0 until BtbBanks).map(
      b => btbRead(w)(b) := btb(w)(b).io.r.resp.data(0)
    )
  )

  // Tag match against the latched pc; a hit also requires that the SRAM
  // actually accepted a read request last cycle and no flush is pending.
  for (w <- 0 until BtbWays) {
    for (b <- 0 until BtbBanks) {
      when (b.U === btbAddr.getBank(pcLatch) && btbRead(w)(b).valid && btbRead(w)(b).tag === btbAddr.getTag(pcLatch)) {
        btbHits(w) := !flush && RegNext(btb(w)(b).io.r.req.fire(), init = false.B)
        btbHitWay := w.U
        for (i <- 0 until predictWidth) {
          btbTargets(i) := btbRead(w)(b).target(i).addr
          btbTypes(i) := btbRead(w)(b).target(i)._type
          // BUG FIX: original indexed btbRead(b)(w) — way/bank swapped;
          // btbRead is Vec(BtbWays, Vec(BtbBanks, _)).
          btbTakens(i) := (btbRead(w)(b).target(i).pred)(1).asBool
        }
      }.otherwise {
        btbHits(w) := false.B
        btbHitWay := DontCare
        for (i <- 0 until predictWidth) {
          btbTargets(i) := DontCare
          btbTypes(i) := DontCare
          btbTakens(i) := DontCare
        }
      }
    }
  }

  val btbHit = btbHits.reduce(_|_)

  // Priority mux which corresponds with inst order:
  // the BTB produces one single prediction (first predicted-taken slot).
  val btbTakenTarget = MuxCase(0.U, btbTakens zip btbTargets)
  val btbTakenType   = MuxCase(0.U, btbTakens zip btbTypes)
  val btbTaken       = btbTakens.reduce(_|_)
  // Record which inst is predicted taken
  val btbTakenIdx = MuxCase(0.U, btbTakens zip (0 until predictWidth).map(_.U))

  // JBTAC, divided into banks, makes prediction for indirect jump except ret.
  val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks)
  def jbtacEntry() = new Bundle {
    val valid = Bool()
    // TODO: don't need full length of tag and target
    val tag = UInt(jbtacAddr.tagBits.W)
    val target = UInt(VAddrBits.W)
    val offset = UInt(log2Up(FetchWidth).W)
  }

  val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = false)))

  val jbtacRead = Wire(Vec(JbtacBanks, jbtacEntry()))

  // Registered per-bank req.fire, aligned with the 1-cycle SRAM read latency.
  // BUG FIX: original wrote Wire(Vec(n, RegInit(...))), which nests hardware
  // inside a Chisel type and does not elaborate; a registered Vec preserves
  // the evident intent.
  val jbtacFire = RegInit(VecInit(Seq.fill(JbtacBanks)(false.B)))
  // Only the pc-addressed bank is read.
  (0 until JbtacBanks).map(
    b=>{
      jbtac(b).reset := reset.asBool
      jbtac(b).io.r.req.valid := io.in.pc.valid && b.U === jbtacAddr.getBank(io.in.pc.bits)
      jbtac(b).io.r.req.bits.setIdx := jbtacAddr.getBankIdx(io.in.pc.bits)
      jbtacFire(b) := jbtac(b).io.r.req.fire()
      jbtacRead(b) := jbtac(b).io.r.resp.data(0)
    }
  )

  val jbtacBank = jbtacAddr.getBank(pcLatch)
  val jbtacHit = jbtacRead(jbtacBank).valid && jbtacRead(jbtacBank).tag === jbtacAddr.getTag(pcLatch) && !flush && jbtacFire(jbtacBank)
  val jbtacHitIdx = jbtacRead(jbtacBank).offset
  val jbtacTarget = jbtacRead(jbtacBank).target

  io.out.valid := RegNext(io.in.pc.valid) && !flush

  // Quick prediction: redirect on a taken BTB hit or any JBTAC hit; kill the
  // instructions after the earliest predicted-taken slot.
  io.s1OutPred.valid := RegNext(io.in.pc.valid)
  io.s1OutPred.bits.redirect := btbHit && btbTaken || jbtacHit
  io.s1OutPred.bits.instrValid := ~LowerMask(btbTakenIdx, FetchWidth) & ~LowerMask(jbtacHitIdx, FetchWidth)
  io.s1OutPred.bits.target := Mux(btbTakenIdx < jbtacHitIdx, btbTakenTarget, jbtacTarget)
  io.s1OutPred.bits.hist := DontCare
  io.s1OutPred.bits.rasSp := DontCare
  io.s1OutPred.bits.rasTopCtr := DontCare

  io.out.bits.pc := pcLatch
  io.out.bits.btb.hits := btbHit
  (0 until FetchWidth).map(i=>io.out.bits.btb.targets(i) := btbTargets(i))
  io.out.bits.jbtac.hitIdx := jbtacHitIdx
  io.out.bits.jbtac.target := jbtacTarget
  // NOTE(review): this DontCare last-connects over the earlier
  // `io.out.bits.tage <> tage.io.out` — presumably intentional while TAGE
  // output is unused downstream; confirm before relying on tage in Stage2.
  io.out.bits.tage := DontCare
  io.out.bits.hist := DontCare
  io.out.bits.btbPred := io.s1OutPred
}
367
// Stage2 currently adds nothing to the Stage1 payload; kept as a distinct
// type so fields can be added later without touching Stage1's interface.
class Stage2To3IO extends Stage1To2IO {
}
370
// Second BPU pipeline stage: currently a pure one-cycle latch between
// Stage1 and Stage3, with no additional prediction logic yet.
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or backend redirects
  val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  io.out.valid := !flushS2 && RegNext(io.in.fire())
  io.in.ready := !io.out.valid || io.out.fire()

  // do nothing but register the payload for one cycle
  io.out.bits := RegEnable(io.in.bits, io.in.fire())
}
387
// Third BPU pipeline stage.
// Combines the BTB/JBTAC/TAGE results latched from Stage2 with predecode
// information from the icache to find the first taken control-flow
// instruction in the fetch line, maintains a speculative RAS (with
// checkpoint recovery on misprediction), and flushes S1/S2 whenever its
// refined prediction disagrees with Stage1's quick one.
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = ValidIO(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Flipped(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
  })

  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  when (io.in.fire()) {
    validLatch := !io.flush
  }.elsewhen (io.out.valid) {
    validLatch := false.B
  }
  // Output fires once predecode arrives for the latched fetch packet.
  io.out.valid := validLatch && io.predecode.valid && !flushS3
  io.in.ready := !validLatch || io.out.valid

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // Get the first taken branch/jal/call/jalr/ret in a fetch line.
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump.
  val brIdx = inLatch.btb.hits & Cat(io.predecode.bits.fuTypes.map { t => ALUOpType.isBranch(t) }).asUInt & io.predecode.bits.mask
  val brTakenIdx = LowestBit(brIdx & inLatch.tage.takens.asUInt, FetchWidth)
  val jalIdx = LowestBit(inLatch.btb.hits & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.jal }).asUInt & io.predecode.bits.mask, FetchWidth)
  val callIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.mask & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.call }).asUInt, FetchWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & io.predecode.bits.mask & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.jalr }).asUInt, FetchWidth)
  val retIdx = LowestBit(io.predecode.bits.mask & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.ret }).asUInt, FetchWidth)

  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, FetchWidth)
  val brNotTakenIdx = brIdx & ~inLatch.tage.takens.asUInt & LowerMask(jmpIdx, FetchWidth)

  // Target priority: ret -> RAS top; jalr -> JBTAC; nothing taken -> fall
  // through; otherwise the BTB target of the first taken slot.
  io.out.bits.target := Mux(jmpIdx === retIdx, rasTopAddr,
    Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
    Mux(jmpIdx === 0.U, inLatch.pc + 4.U, // TODO: RVC
    PriorityMux(jmpIdx, inLatch.btb.targets))))
  io.out.bits.instrValid := LowerMask(jmpIdx, FetchWidth).asTypeOf(Vec(FetchWidth, Bool()))
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  val firstHist = inLatch.btbPred.bits.hist(0)
  // There may be several not-taken branches before the first jump instruction,
  // so we need to calculate how many zeroes each instruction shifts into its
  // global history. Each history excludes the instruction's own direction.
  val histShift = Wire(Vec(FetchWidth, UInt(log2Up(FetchWidth).W)))
  val shift = Wire(Vec(FetchWidth, Vec(FetchWidth, UInt(1.W))))
  (0 until FetchWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), FetchWidth)).asTypeOf(Vec(FetchWidth, UInt(1.W))))
  for (j <- 0 until FetchWidth) {
    var tmp = 0.U
    for (i <- 0 until FetchWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until FetchWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // Flush BPU and redirect when the taken/not-taken decision or the target
  // differs from what Stage1 predicted. (An earlier dead assignment of
  // io.out.bits.redirect was removed; last-connect made it unreachable.)
  io.out.bits.redirect := !inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.valid

  // speculative update RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  // BUG FIX: in Scala `+` binds tighter than `<<`, so the original
  // `inLatch.pc + OHToUInt(callIdx) << 2.U + 4.U` parsed as
  // `(pc + callOffset) << 6`. Parenthesize to compute pc + callOffset*4 + 4,
  // i.e. the return address of the call instruction.
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 2.U) + 4.U
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.valid) {
    when (jmpIdx === callIdx) {
      // Push on a new call; recursive calls just bump the top counter.
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      // Pop on ret; decrement the nesting counter first.
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS on backend-reported misprediction
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.valid && io.redirectInfo.misPred) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back global history in S1 if S3 redirects
  io.s1RollBackHist := PriorityMux(jmpIdx, io.out.bits.hist)
}
501
// Top-level branch prediction unit wiring the three pipelined stages:
// S1 produces a quick BTB/JBTAC-based guess (btbOut), S2 is a latch stage,
// and S3 refines the prediction with predecode info (tageOut) and can flush
// S1/S2 when it disagrees.
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush pipeline if misPred and update bpu based on redirect signals from brq
    val redirectInfo = Flipped(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    // quick prediction from Stage1
    val btbOut = ValidIO(new BranchPrediction)
    // refined prediction from Stage3
    val tageOut = ValidIO(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // S1 is flushed by either an S3 redirect or a backend redirect.
  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  // S3 only flushes on backend redirects (it is the source of flushBPU).
  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo
}
537}