// xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision 2f931f37633162f8edadae653165cce3c4e6465b)
package xiangshan.frontend

import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
import chisel3.util.experimental.BoringUtils
import xiangshan.backend.decode.XSTrap

/** Address view used to index banked predictor tables (e.g. BTB, JBTAC).
  *
  * Splits a VAddrBits-wide virtual address into, from MSB to LSB:
  * tag | idx | offset. The offset is a single bit because instructions
  * are at least 2-byte aligned (RVC), so only bit 0 above the byte
  * granularity is kept.
  *
  * NOTE: field declaration order defines the bit layout used by
  * `asTypeOf` below — do not reorder the fields.
  */
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // Tag width is whatever remains above the index and the 1-bit offset.
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  // Reinterpret a raw address as this tag/idx/offset layout.
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // Low idx bits select the bank; the remaining idx bits address within a bank.
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}

/** Payload passed from BPU Stage1 to Stage2: the latched fetch PC, the raw
  * results of each sub-predictor (BTB, JBTAC, TAGE), the speculative global
  * history seen by each instruction slot, and Stage1's quick prediction
  * (btbPred) so later stages can compare against it.
  */
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  // BTB results: one hit bit and one target per PredictWidth slot.
  val btb = new Bundle {
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  // JBTAC (indirect-jump) result: one-hot hit index plus its predicted target.
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  // TAGE direction predictions. Note these are FetchWidth wide; Stage3
  // expands each entry to two PredictWidth slots via Fill(2, _).
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  // Per-slot speculative global history (exclusive of the slot's own outcome).
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  // Stage1's own quick prediction, forwarded downstream for comparison.
  val btbPred = ValidIO(new BranchPrediction)
}

/** First BPU stage: looks up BTB, JBTAC and TAGE in parallel for the incoming
  * fetch PC, produces a one-cycle "quick" prediction (s1OutPred) for the IFU,
  * and forwards the raw predictor results to Stage2.
  *
  * Speculatively maintains the global history register (ghr); the history can
  * be rolled back either from the backend (misprediction redirect) or from
  * Stage3 (s3RollBackHist/s3Taken).
  */
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to ifu, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // Stage1 never back-pressures the PC generator.
  io.in.pc.ready := true.B

  // flush Stage1 when io.flush
  // NOTE(review): flushS1 is computed but never read anywhere in this module
  // (flush suppression below uses io.flush directly) — looks like dead
  // hardware; confirm before removing.
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)

  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // Combinational bypass: when the ghr is being written this cycle, consumers
  // below see the new value immediately instead of the stale register.
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor (replaced by a stub when the complex predictor is disabled)
  val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  io.out.bits.tage <> tage.io.out
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
  // TODO: pass real mask in
  // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
  val maskLatch = Fill(PredictWidth, 1.U(1.W))

  // Shorthand for the backend redirect bundle used by the update paths below.
  val r = io.redirectInfo.redirect
  // Reconstruct the PC of the fetch packet that contained the mispredicted
  // instruction (fetchIdx counts 2-byte slots, hence the << 1).
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)
  // BTB
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // BTB update from the backend redirect information.
  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update._type := r._type
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // Unpack BTB lookup results.
  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_._type))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))


  // JBTAC: indirect-jump target predictor, indexed with pc hashed with history.
  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  // JBTAC update from the backend redirect information.
  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update._type := r._type
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist
  jbtac.io.update.isRVC := r.isRVC

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // Calculate the global history seen by each instruction slot: each
  // not-taken branch before slot j contributes one extra zero-shift to
  // slot j's history. shift(i)(j) is 1 iff slot i is a not-taken branch
  // strictly before slot j; histShift(j) sums those contributions.
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  (0 until PredictWidth).map(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))

  // The ghr is written on any flush, on a Stage1 redirect, or when the fetch
  // line contains not-taken branches that must shift zeroes into the history.
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  // One-hot position of the taken BTB branch / JBTAC indirect hit (0 if none).
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if backend redirects, restore history from backend;
  // if stage3 redirects, restore history from stage3;
  // if stage1 redirects, speculatively update history;
  // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
  newGhr := Mux(io.redirectInfo.flush(),    (r.hist << 1.U) | !(r._type === BTBtype.B && !r.taken),
            Mux(io.flush,                   Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
            Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
                                            io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // One-hot position of the first redirecting instruction (BTB or JBTAC).
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)
  // Results are valid one cycle after the request (SRAM latency).
  io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  io.s1OutPred.valid := io.out.valid
  io.s1OutPred.bits.redirect := btbTaken || jbtacHit
  // Valid-instruction mask: everything up to (and, for a late RVI jump or a
  // non-RVC taken instruction, including) the redirecting slot.
  io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
                                  Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
                                  LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
  // Fall-through (sequential) target when nothing is taken, otherwise the
  // BTB or JBTAC target depending on which slot redirected.
  io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
  io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
  // io.s1OutPred.bits.btbVictimWay := btbWriteWay
  io.s1OutPred.bits.predCtr := btbCtrs
  io.s1OutPred.bits.btbHit := btbValids
  io.s1OutPred.bits.tageMeta := DontCare // TODO: enableBPD
  io.s1OutPred.bits.rasSp := DontCare
  io.s1OutPred.bits.rasTopCtr := DontCare

  // Forward raw predictor results to Stage2.
  io.out.bits.pc := pcLatch
  io.out.bits.btb.hits := btbValids.asUInt
  (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
  io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
  io.out.bits.jbtac.target := jbtacTarget
  // TODO: we don't need this repeatedly!
  io.out.bits.hist := io.s1OutPred.bits.hist
  io.out.bits.btbPred := io.s1OutPred



  // debug info
  XSDebug("in:(%d %d)   pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x _type=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r._type, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3:  s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)

}

/** Stage2 forwards its input unchanged, so the Stage2-to-Stage3 payload is
  * exactly the Stage1-to-Stage2 payload; kept as a distinct type so the
  * pipeline interfaces stay self-describing.
  */
class Stage2To3IO extends Stage1To2IO

/** Second BPU stage: a pure one-entry pipeline buffer. It latches the
  * Stage1 payload, forwards it unchanged to Stage3, and drops in-flight
  * data when flushed (by Stage3's redirect or a backend redirect).
  */
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or backend redirects; stays high until the
  // next input is accepted.
  val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { inLatch := io.in.bits }
  val validLatch = RegInit(false.B)
  // Priority order matters (Chisel when/elsewhen): flush wins over a new
  // input, which wins over the output being consumed.
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  io.out.valid := !io.flush && !flushS2 && validLatch
  // Ready when empty, or when the current entry drains this cycle.
  io.in.ready := !validLatch || io.out.fire()

  // Pass-through: Stage2 performs no computation of its own.
  io.out.bits := inLatch

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc)
  XSDebug(io.flush, "flush!!!\n")
}

/** Third (final) BPU stage: combines the buffered Stage1/Stage2 predictor
  * results with the icache's predecode information to find the first truly
  * redirecting instruction in the fetch line, maintains the return address
  * stack (RAS) speculatively, recomputes per-slot global history, and
  * flushes Stage1/Stage2 when its conclusion disagrees with Stage1's quick
  * prediction.
  */
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = ValidIO(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  // Suppress output between a flush and the next accepted input.
  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  // Priority order matters: flush > new input > output consumed.
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.valid) {
    validLatch := false.B
  }
  // Output fires only once the predecode info for this line has arrived.
  io.out.valid := validLatch && io.predecode.valid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.valid

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & io.predecode.bits.mask
  // Taken direction comes from TAGE when the complex predictor is enabled,
  // otherwise from the BTB's 2-bit counter MSB. Fill(2, _) widens each
  // FetchWidth-wide TAGE result to two PredictWidth (2-byte) slots.
  val brTakenIdx = if(EnableBPD) {
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt), PredictWidth)
  } else {
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt), PredictWidth)
  }
  // TODO: btb doesn't need to hit, jalIdx/callIdx can be calculated based on instructions read in Cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & io.predecode.bits.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  // First control-flow-changing instruction of any kind (one-hot, 0 if none).
  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  // Branches before jmpIdx that are predicted not-taken (they shift zeroes
  // into the global history below).
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if(EnableBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))

  // A 4-byte (RVI) jump whose first half is the last valid slot of the line:
  // its second half spills into the next fetch line.
  val lateJump = jmpIdx === HighestBit(io.predecode.bits.mask, PredictWidth) && !io.predecode.bits.isRVC(OHToUInt(jmpIdx))

  // Next-fetch target: sequential fall-through, RAS top for ret, JBTAC for
  // jalr, otherwise the BTB target of the jumping slot.
  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(io.predecode.bits.mask) << 1.U),
                        Mux(jmpIdx === retIdx, rasTopAddr,
                        Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
                        PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call's target can be calculated here

  // Valid mask up to (and, for a non-RVC jump, including) the jumping slot.
  io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, io.predecode.bits.mask,
                            Mux(!io.predecode.bits.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
                            LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.lateJump := lateJump
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  //io.out.bits._type := Mux(jmpIdx === retIdx, BTBtype.R,
  //  Mux(jmpIdx === jalrIdx, BTBtype.I,
  //  Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
  val firstHist = inLatch.btbPred.bits.hist(0)
  // there may be several notTaken branches before the first jump instruction,
  // so we need to calculate how many zeroes should each instruction shift in its global history.
  // each history is exclusive of instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // shift(i)(j) is 1 iff slot i is a not-taken branch strictly before slot j.
  (0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush BPU and redirect when target differs from the target predicted in Stage1:
  // either the taken/not-taken conclusion differs (XOR term), or both redirect
  // but to different targets.
  // io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.valid

  // speculative update RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  // Return address = call PC + instruction length (2 for RVC, 4 for RVI).
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, io.predecode.bits.isRVC), 2.U, 4.U)
  // Recursive calls to the same site reuse the top entry and bump its counter.
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.valid && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      // Pop only when the nesting counter reaches one; otherwise decrement it.
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS on a backend redirect
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back global history in S1 if S3 redirects: history of the jumping
  // slot when taken, otherwise history shifted by the not-taken branch count.
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.valid, "out:%d pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, inLatch.pc, io.out.bits.redirect, io.predecode.bits.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecode.valid=%d\n", validLatch, io.predecode.valid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's TEMP Perf Cnt — sourced here, sunk by the counters in class BPU.
  BoringUtils.addSource(io.out.valid, "MbpS3Cnt")
  BoringUtils.addSource(io.out.valid && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.valid && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.valid && (inLatch.btbPred.bits.redirect
              && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}

/** Top-level branch prediction unit: wires the three pipeline stages
  * together, exposes Stage1's quick prediction (btbOut) and Stage3's final
  * prediction (tageOut), and hosts the temporary BoringUtils-based
  * performance counters.
  *
  * NOTE: the class body continues past this excerpt (its closing brace is
  * below the visible region).
  */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush pipeline if misPred and update bpu based on redirect signals from brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    // Stage1's quick prediction and Stage3's refined prediction.
    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = ValidIO(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // Stage1: flushed by either a Stage3 redirect or a backend redirect.
  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  // Stage3: flushed only by the backend (it is the source of flushBPU).
  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temp and ugly code, when perf counters is added( may after adding CSR), please mv the below counter
  // (name, padding-for-aligned-printout) pairs; the name doubles as the
  // BoringUtils channel id sourced elsewhere in the design.
  val bpuPerfCntList = List(
    ("MbpInstr","         "),
    ("MbpRight","         "),
    ("MbpWrong","         "),
    ("MbpBRight","        "),
    ("MbpBWrong","        "),
    ("MbpJRight","        "),
    ("MbpJWrong","        "),
    ("MbpIRight","        "),
    ("MbpIWrong","        "),
    ("MbpRRight","        "),
    ("MbpRWrong","        "),
    ("MbpS3Cnt","         "),
    ("MbpS3TageRed","     "),
    ("MbpS3TageRedDir","  "),
    ("MbpS3TageRedTar","  ")
  )

  // One counter per event; each increments while its boring-wire condition holds.
  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  (bpuPerfCnts zip bpuPerfCntConds) map { case (cnt, cond) => { when (cond) { cnt := cnt + 1.U }}}

  for(i <- bpuPerfCntList.indices) {
    BoringUtils.addSink(bpuPerfCntConds(i), bpuPerfCntList(i)._1)
  }

  // Dump all counters when the simulation trap signal fires.
  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // if (!p.FPGAPlatform) {
    when (xsTrap) {
      printf("=================BPU's PerfCnt================\n")
      for(i <- bpuPerfCntList.indices) {
        printf(bpuPerfCntList(i)._1 + bpuPerfCntList(i)._2 + " <- " + "%d\n", bpuPerfCnts(i))
      }
    }
  // }
480}