// xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision 8d22bbae737677f99c224a91c149117f0a1b8788)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import utils._
6import xiangshan._
7import xiangshan.backend.ALUOpType
8import xiangshan.backend.JumpOpType
9import chisel3.util.experimental.BoringUtils
10import xiangshan.backend.decode.XSTrap
11
// Partition of a virtual address into {tag, idx, offset} fields for a banked
// predictor table. NOTE: field declaration order fixes the bit layout that
// `asTypeOf` maps onto, so the order here must not change.
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // All address bits above idx and the 1-bit halfword offset form the tag.
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)  // halfword offset (RVC granularity)

  // Zero/truncate a raw address to VAddrBits, then view it as this bundle.
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // The low log2Up(banks) bits of idx select the bank...
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  // ...and the remaining idx bits address a row within that bank.
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}
25
// Payload carried from BPU Stage1 to Stage2: the latched fetch pc, the raw
// BTB / JBTAC / TAGE lookup results, the per-instruction history snapshots,
// and Stage1's quick prediction (kept so Stage3 can compare against it).
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  // BTB result: per-slot hit bits and predicted targets
  val btb = new Bundle {
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  // JBTAC (indirect jump target cache) result: one-hot hit index and target
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  // TAGE result, at FetchWidth (instruction) granularity
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  // speculative global history as seen by each slot in the fetch line
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  // Stage1's quick prediction, forwarded for later comparison in Stage3
  val btbPred = ValidIO(new BranchPrediction)
}
43
/**
 * BPU pipeline Stage1: performs BTB / JBTAC / TAGE lookups for the incoming
 * fetch pc, produces a quick one-cycle prediction for the IFU, and
 * speculatively maintains the global history register (GHR).
 *
 * Fix vs. previous revision: `maskLatch` used to read `btb.io.in.mask`
 * BEFORE `btb` was declared. In a Scala class body such a forward reference
 * is not a compile error — the not-yet-initialized `btb` is simply null at
 * that point, which fails at hardware elaboration. The BTB instantiation and
 * its input wiring are now placed above the `maskLatch` register.
 */
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to ifu, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // Stage1 can always accept a new pc request
  io.in.pc.ready := true.B

  // flush Stage1 when io.flush; sticky until the next request fires
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)

  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // use hist as global history!!! (bypasses the register in the update cycle)
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor (faked away when the advanced direction predictor is off)
  val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  io.out.bits.tage <> tage.io.out
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())

  val r = io.redirectInfo.redirect
  // reconstruct the fetch-packet pc from the redirecting instruction's index
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)

  // BTB — must be instantiated before maskLatch, which samples its input mask
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // TODO: pass real mask in
  val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())

  // BTB update channel, driven from the backend redirect
  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update._type := r._type
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // unpack BTB outputs
  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_._type))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))

  // JBTAC: indirect-jump target cache
  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update._type := r._type
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // calculate global history of each instr:
  // each slot j sees firstHist shifted left by the number of predicted
  // not-taken branches strictly before it in the fetch line.
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // row i contributes a 1 to every column after i if slot i is a not-taken branch
  (0 until PredictWidth).map(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))

  // update ghr on a flush, a Stage1 redirect, or when the just-looked-up
  // fetch line contains not-taken branches whose zeroes must be shifted in
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire()) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if backend redirects, restore history from backend;
  // if stage3 redirects, restore history from stage3;
  // if stage1 redirects, speculatively update history;
  // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
  newGhr := Mux(io.redirectInfo.flush(),    (r.hist << 1.U) | !(r._type === BTBtype.B && !r.taken),
            Mux(io.flush,                   Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
            Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
                                            io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // redirect based on BTB and JBTAC: lowest set bit is the first taken slot
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)
  io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  io.s1OutPred.valid := io.out.valid
  io.s1OutPred.bits.redirect := btbTaken || jbtacHit
  // valid instrs: everything up to (and, for a non-RVC taken jump, including
  // both halfwords of) the first taken slot; whole mask on a late jump
  io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
                                  Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
                                  LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
  io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
  io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
  // io.s1OutPred.bits.btbVictimWay := btbWriteWay
  io.s1OutPred.bits.predCtr := btbCtrs
  io.s1OutPred.bits.btbHit := btbValids
  io.s1OutPred.bits.tageMeta := DontCare
  io.s1OutPred.bits.rasSp := DontCare
  io.s1OutPred.bits.rasTopCtr := DontCare

  // payload to Stage2
  io.out.bits.pc := pcLatch
  io.out.bits.btb.hits := btbValids.asUInt
  (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
  io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
  io.out.bits.jbtac.target := jbtacTarget
  // TODO: we don't need this repeatedly!
  io.out.bits.hist := io.s1OutPred.bits.hist
  io.out.bits.btbPred := io.s1OutPred

  // debug info
  XSDebug("in:(%d %d)   pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x _type=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r._type, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3:  s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)

}
211
// Stage2 currently adds nothing to the Stage1 payload; this alias exists so
// extra Stage2 results can be added later without touching Stage3's port type.
class Stage2To3IO extends Stage1To2IO {
}
214
/**
 * BPU pipeline Stage2: a one-entry buffer that re-times the Stage1 payload
 * by one cycle. It performs no computation of its own and can be flushed by
 * Stage3 or a backend redirect.
 */
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // Sticky flush: asserted from a flush until the next input fires, so a
  // payload latched just before the flush is also discarded.
  val flushSticky = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)

  // Buffered payload and its valid bit.
  val bufData = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { bufData := io.in.bits }
  val bufValid = RegInit(false.B)
  // Priority (same as the original when-chain): flush clears, an input fire
  // sets, an output fire clears, otherwise hold.
  bufValid := Mux(io.flush, false.B,
              Mux(io.in.fire(), true.B,
              Mux(io.out.fire(), false.B, bufValid)))

  io.out.valid := bufValid && !io.flush && !flushSticky
  io.in.ready := io.out.fire() || !bufValid

  // do nothing
  io.out.bits := bufData

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", bufValid, bufData.pc)
  XSDebug(io.flush, "flush!!!\n")
}
248
// BPU pipeline Stage3: combines the latched Stage1/Stage2 lookup results with
// predecode info from the icache to compute the final prediction, maintains a
// speculative RAS, and flushes the earlier stages when its prediction differs
// from Stage1's quick prediction.
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = ValidIO(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  // sticky flush: discard an in-flight payload latched before the flush
  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  // priority: flush > new input > output consumed
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.valid) {
    validLatch := false.B
  }
  // output only once the matching predecode info has arrived
  io.out.valid := validLatch && io.predecode.valid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.valid

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & io.predecode.bits.mask
  // direction comes from TAGE when present, otherwise from the BTB 2-bit ctr MSB
  val brTakenIdx = if(HasBPD) {
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt), PredictWidth)
  } else {
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt), PredictWidth)
  }
  // TODO: btb doesn't need to hit, jalIdx/callIdx can be calculated based on instructions read in Cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & io.predecode.bits.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  // first taken control-flow instruction of any kind (one-hot, 0 if none)
  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if(HasBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))

  // taken jump sits in the last valid halfword but is a 4-byte instruction,
  // so its second half spills into the next fetch line
  val lateJump = jmpIdx === HighestBit(io.predecode.bits.mask, PredictWidth) && !io.predecode.bits.isRVC(OHToUInt(jmpIdx))

  // next-pc: fall-through if nothing taken; RAS top for ret; JBTAC for jalr;
  // otherwise the BTB target of the taken slot
  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(io.predecode.bits.mask) << 1.U),
                        Mux(jmpIdx === retIdx, rasTopAddr,
                        Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
                        PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call's target can be calculated here

  // valid instrs: everything up to (and, for a non-RVC jump, including both
  // halfwords of) the taken slot; the whole mask if none taken or late jump
  io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, io.predecode.bits.mask,
                            Mux(!io.predecode.bits.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
                            LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // pass through Stage1's bookkeeping fields
  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  //io.out.bits._type := Mux(jmpIdx === retIdx, BTBtype.R,
  //  Mux(jmpIdx === jalrIdx, BTBtype.I,
  //  Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
  val firstHist = inLatch.btbPred.bits.hist(0)
  // there may be several notTaken branches before the first jump instruction,
  // so we need to calculate how many zeroes should each instruction shift in its global history.
  // each history is exclusive of instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // row i contributes a 1 to every column after i if slot i is a not-taken branch
  (0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush BPU and redirect when target differs from the target predicted in Stage1
  // io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.valid

  // speculative update RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  // return address = call pc + 2 (RVC) or + 4
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, io.predecode.bits.isRVC), 2.U, 4.U)
  // recursive calls to the same site just bump the top entry's counter
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.valid && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      // pop only when the nesting counter reaches 1; otherwise decrement it
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back global history in S1 if S3 redirects
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.valid, "out:%d pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, inLatch.pc, io.out.bits.redirect, io.predecode.bits.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecode.valid=%d\n", validLatch, io.predecode.valid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's TEMP Perf Cnt
  BoringUtils.addSource(io.out.valid, "MbpS3Cnt")
  BoringUtils.addSource(io.out.valid && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.valid && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.valid && (inLatch.btbPred.bits.redirect
              && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}
401
/**
 * Top-level branch prediction unit: chains the three BPU stages, distributes
 * flush/redirect signals, and hosts temporary boring-wire perf counters.
 */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush pipeline if misPred and update bpu based on redirect signals from brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = ValidIO(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // flush distribution: a backend redirect flushes every stage, while a
  // Stage3 redirect flushes only the two stages behind it
  val backendFlush = io.redirectInfo.flush()
  s1.io.flush := s3.io.flushBPU || backendFlush
  s2.io.flush := s3.io.flushBPU || backendFlush
  s3.io.flush := backendFlush

  // Stage1: request in, quick prediction out
  s1.io.redirectInfo <> io.redirectInfo
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  // stage-to-stage handshakes
  s1.io.out <> s2.io.in
  s2.io.out <> s3.io.in

  // Stage3: predecode in, final prediction out
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temp and ugly code, when perf counters is added( may after adding CSR), please mv the below counter
  // (name, padding) pairs; the padding aligns the printed report columns
  val bpuPerfCntList = List(
    ("MbpInstr","         "),
    ("MbpRight","         "),
    ("MbpWrong","         "),
    ("MbpBRight","        "),
    ("MbpBWrong","        "),
    ("MbpJRight","        "),
    ("MbpJWrong","        "),
    ("MbpIRight","        "),
    ("MbpIWrong","        "),
    ("MbpRRight","        "),
    ("MbpRWrong","        "),
    ("MbpS3Cnt","         "),
    ("MbpS3TageRed","     "),
    ("MbpS3TageRedDir","  "),
    ("MbpS3TageRedTar","  ")
  )

  // one counter register and one boring-wire condition per entry
  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  for ((cnt, inc) <- bpuPerfCnts zip bpuPerfCntConds) {
    when (inc) { cnt := cnt + 1.U }
  }

  // hook each condition up to its named boring-wire source elsewhere
  for (((name, _), cond) <- bpuPerfCntList zip bpuPerfCntConds) {
    BoringUtils.addSink(cond, name)
  }

  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // dump the counters once when the simulation traps
  // if (!p.FPGAPlatform) {
    when (xsTrap) {
      printf("=================BPU's PerfCnt================\n")
      for (((name, pad), cnt) <- bpuPerfCntList zip bpuPerfCnts) {
        printf(name + pad + " <- " + "%d\n", cnt)
      }
    }
  // }
}