// xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision b5d0eb3c6ed0f063361c37135cf553ad99d7a4ee)
package xiangshan.frontend

import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
import chisel3.util.experimental.BoringUtils
import xiangshan.backend.decode.XSTrap
11
/**
 * Address-view helper for banked predictor tables: splits a virtual address
 * into {tag, idx, offset} fields and derives bank / in-bank indices.
 *
 * @param idxBits total number of index bits (bank select + in-bank index)
 * @param banks   number of banks; must be a power of two for getBank/getBankIdx
 */
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // Upper bits left over after the index and the 1-bit half-word offset.
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  // Reinterpret a raw PC as this {tag, idx, offset} layout.
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // Low idx bits select the bank; the remaining bits index within the bank.
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}
25
/**
 * Payload carried from BPU Stage1 to Stage2: the fetch PC plus the raw
 * responses of each predictor structure (BTB, JBTAC, TAGE) queried in Stage1,
 * the per-slot speculative history, and Stage1's quick prediction.
 */
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  // Speculative global-history snapshot per prediction slot.
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  // Stage1's quick prediction; Stage3 re-checks it against predecode info.
  val btbPred = ValidIO(new BranchPrediction)
}
43
/**
 * BPU Stage1: queries BTB, JBTAC and TAGE in parallel, maintains the
 * speculative global history register (GHR), and emits a quick one-cycle
 * prediction (s1OutPred) that Stage3 later verifies with predecode info.
 */
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend: redirect/update info resolved by the branch queue
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to ifu, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // Stage1 can always accept a new fetch PC.
  io.in.pc.ready := true.B

  // flush Stage1 when io.flush
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)

  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // Bypass a pending update so predictors see the newest history this cycle.
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor
  val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  io.out.bits.tage <> tage.io.out
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
  // TODO: pass real mask in
  // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
  val maskLatch = Fill(PredictWidth, 1.U(1.W))

  val r = io.redirectInfo.redirect
  // Reconstruct the fetch-packet start PC of the redirecting instruction
  // (fetchIdx is in half-word units, hence the << 1).
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)
  // BTB
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update.btbType := r.btbType
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))


  // JBTAC: jump-target predictor for indirect branches, indexed with history.
  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update.btbType := r.btbType
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist
  jbtac.io.update.isRVC := r.isRVC

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // calculate global history of each instr
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // A not-taken branch in slot i contributes one extra zero-shift to every
  // later slot j > i (~LowerMask masks off slots <= i).
  (0 until PredictWidth).map(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))

  // update ghr
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire()) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if backend redirects, restore history from backend;
  // if stage3 redirects, restore history from stage3;
  // if stage1 redirects, speculatively update history;
  // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
  newGhr := Mux(io.redirectInfo.flush(),    (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken),
            Mux(io.flush,                   Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
            Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
                                            io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // redirect based on BTB and JBTAC
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)
  io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  io.s1OutPred.valid := io.out.valid
  io.s1OutPred.bits.redirect := btbTaken || jbtacHit
  // Valid mask: on a redirect, keep slots up to the taken instruction
  // (inclusive of the next half-word when it is a 4-byte RVI instruction).
  io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
                                  Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
                                  LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
  io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
  io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
  // io.s1OutPred.bits.btbVictimWay := btbWriteWay
  io.s1OutPred.bits.predCtr := btbCtrs
  io.s1OutPred.bits.btbHit := btbValids
  io.s1OutPred.bits.tageMeta := DontCare // TODO: enableBPD
  io.s1OutPred.bits.rasSp := DontCare
  io.s1OutPred.bits.rasTopCtr := DontCare

  io.out.bits.pc := pcLatch
  io.out.bits.btb.hits := btbValids.asUInt
  (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
  io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
  io.out.bits.jbtac.target := jbtacTarget
  // TODO: we don't need this repeatedly!
  io.out.bits.hist := io.s1OutPred.bits.hist
  io.out.bits.btbPred := io.s1OutPred



  // debug info
  XSDebug("in:(%d %d)   pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3:  s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)

}
213
// Stage2 adds nothing to the payload yet: it only re-times Stage1's output.
class Stage2To3IO extends Stage1To2IO {
}
216
/**
 * BPU Stage2: a plain pipeline-register stage. It latches the Stage1 payload
 * for one cycle and forwards it unchanged, dropping it on a flush.
 */
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or backend redirects
  val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { inLatch := io.in.bits }
  // validLatch: a latched entry is waiting to be consumed by Stage3.
  val validLatch = RegInit(false.B)
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  io.out.valid := !io.flush && !flushS2 && validLatch
  io.in.ready := !validLatch || io.out.fire()

  // do nothing
  io.out.bits := inLatch

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc)
  XSDebug(io.flush, "flush!!!\n")
}
250
/**
 * BPU Stage3: combines the latched Stage1/Stage2 predictor responses with
 * predecode information from the icache to pick the first taken
 * branch/jal/call/jalr/ret in the fetch packet, maintains the RAS, and
 * flushes the earlier stages when its prediction disagrees with Stage1's.
 */
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = Decoupled(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  // Predecode may arrive in a different cycle than the Stage2 payload,
  // so it is latched separately.
  val predecodeLatch = RegInit(0.U.asTypeOf(io.predecode.bits))
  val predecodeValidLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  when (io.predecode.valid) { predecodeLatch := io.predecode.bits }
  when (io.flush || io.out.fire()) {
    predecodeValidLatch := false.B
  }.elsewhen (io.predecode.valid) {
    predecodeValidLatch := true.B
  }

  // Use the live predecode when present this cycle, otherwise the latched one.
  val predecodeValid = io.predecode.valid || predecodeValidLatch
  val predecode = Mux(io.predecode.valid, io.predecode.bits, predecodeLatch)
  io.out.valid := validLatch && predecodeValid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.fire()

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & predecode.mask
  // With BPD the direction comes from TAGE (Fill(2, _) widens FetchWidth to
  // PredictWidth); otherwise fall back to the BTB 2-bit counter MSB.
  val brTakenIdx = if(EnableBPD) {
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt), PredictWidth)
  } else {
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt), PredictWidth)
  }
  // TODO: btb doesn't need to hit, jalIdx/callIdx can be calculated based on instructions read in Cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & predecode.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if(EnableBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))

  // lateJump: the taken RVI instruction sits in the last half-word slot,
  // so its second half belongs to the next fetch packet.
  val lateJump = jmpIdx === HighestBit(predecode.mask, PredictWidth) && !predecode.isRVC(OHToUInt(jmpIdx))

  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(predecode.mask) << 1.U),
                        Mux(jmpIdx === retIdx, rasTopAddr,
                        Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
                        PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call's target can be calculated here

  io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, predecode.mask,
                            Mux(!predecode.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
                            LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.lateJump := lateJump
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  //io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R,
  //  Mux(jmpIdx === jalrIdx, BTBtype.I,
  //  Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
  val firstHist = inLatch.btbPred.bits.hist(0)
  // there may be several notTaken branches before the first jump instruction,
  // so we need to calculate how many zeroes should each instruction shift in its global history.
  // each history is exclusive of instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  (0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush BPU and redirect when target differs from the target predicted in Stage1
  // io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.fire()

  // speculative update RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  // Return address = call PC + instruction length (2 for RVC, 4 for RVI).
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, predecode.isRVC), 2.U, 4.U)
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  // Repeated calls to the same address only bump the top entry's counter.
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.fire() && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back global history in S1 if S3 redirects
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.fire(), "out:(%d %d) pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, io.out.ready, inLatch.pc, io.out.bits.redirect, predecode.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecodeValid=%d\n", validLatch, predecodeValid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's TEMP Perf Cnt
  BoringUtils.addSource(io.out.fire(), "MbpS3Cnt")
  BoringUtils.addSource(io.out.fire() && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect
              && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}
416
/**
 * Top-level branch prediction unit: wires the three pipeline stages together,
 * exposes the quick Stage1 prediction (btbOut) and the verified Stage3
 * prediction (tageOut), and hosts temporary boring-wire perf counters.
 */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush pipeline if misPred and update bpu based on redirect signals from brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = Decoupled(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // Stage1 and Stage2 are flushed both by Stage3 and by backend redirects;
  // Stage3 itself is only flushed by the backend.
  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temp and ugly code, when perf counters is added( may after adding CSR), please mv the below counter
  val bpuPerfCntList = List(
    ("MbpInstr","         "),
    ("MbpRight","         "),
    ("MbpWrong","         "),
    ("MbpBRight","        "),
    ("MbpBWrong","        "),
    ("MbpJRight","        "),
    ("MbpJWrong","        "),
    ("MbpIRight","        "),
    ("MbpIWrong","        "),
    ("MbpRRight","        "),
    ("MbpRWrong","        "),
    ("MbpS3Cnt","         "),
    ("MbpS3TageRed","     "),
    ("MbpS3TageRedDir","  "),
    ("MbpS3TageRedTar","  ")
  )

  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  (bpuPerfCnts zip bpuPerfCntConds) map { case (cnt, cond) => { when (cond) { cnt := cnt + 1.U }}}

  // Each counter's increment condition is driven elsewhere via BoringUtils.
  for(i <- bpuPerfCntList.indices) {
    BoringUtils.addSink(bpuPerfCntConds(i), bpuPerfCntList(i)._1)
  }

  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // Dump all counters when the simulation traps.
  // if (!p.FPGAPlatform) {
    when (xsTrap) {
      printf("=================BPU's PerfCnt================\n")
      for(i <- bpuPerfCntList.indices) {
        printf(bpuPerfCntList(i)._1 + bpuPerfCntList(i)._2 + " <- " + "%d\n", bpuPerfCnts(i))
      }
    }
  // }
}