xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision efa0419616e95665b4410f2f61ec54b5fd3e29bd)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import utils._
6import xiangshan._
7import xiangshan.backend.ALUOpType
8import xiangshan.backend.JumpOpType
9import chisel3.util.experimental.BoringUtils
10import xiangshan.backend.decode.XSTrap
11
// Virtual-address view used to index banked predictor tables.
// Layout (MSB -> LSB): tag | idx | offset, where offset is the single
// low-order bit (addresses here are halfword-granular; see `<< 1.U` uses below).
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // Tag occupies whatever remains of VAddrBits after idx and the 1-bit offset.
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  // Reinterpret a raw address as this tag/idx/offset layout.
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // Low log2Up(banks) bits of idx select the bank...
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  // ...and the remaining idx bits form the row index within that bank.
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}
25
// Payload passed from BPU Stage1 to Stage2: the fetch PC plus the raw
// per-predictor results (BTB / JBTAC / TAGE), the speculative global
// history of each slot, and Stage1's quick prediction for later checking.
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    // Per-slot BTB hit bitmap and the predicted target of each slot.
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    // One-hot index of the indirect-jump hit (0 if no hit) and its target.
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    // NOTE(review): TAGE fields are FetchWidth-wide while BTB/JBTAC are
    // PredictWidth-wide — presumably one TAGE entry covers a 4-byte slot;
    // confirm against Tage's io.out definition.
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  // Speculative global history seen by each instruction slot.
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  // Stage1's quick prediction, carried along so Stage3 can detect disagreement.
  val btbPred = ValidIO(new BranchPrediction)
}
43
// BPU Stage1: makes a quick, same-fetch-window prediction from the BTB and
// JBTAC (plus a TAGE lookup whose result is forwarded to later stages), and
// speculatively maintains the global history register (GHR).
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to ifu, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // Stage1 always accepts a new PC; backpressure is handled downstream.
  io.in.pc.ready := true.B

  // flush Stage1 when io.flush
  // NOTE(review): flushS1 is computed but not read anywhere in this module.
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)
  // Hold the previous cycle's outputs when no new request arrived, so the
  // combinational results below stay stable across idle cycles.
  val s1OutPredLatch = RegEnable(io.s1OutPred.bits, RegNext(io.in.pc.fire()))
  val outLatch = RegEnable(io.out.bits, RegNext(io.in.pc.fire()))

  // Output-valid state: set on a new request, cleared on flush or handshake.
  // Priority: flush > input fire > output fire.
  val s1Valid = RegInit(false.B)
  when (io.flush) {
    s1Valid := false.B
  }.elsewhen (io.in.pc.fire()) {
    s1Valid := true.B
  }.elsewhen (io.out.fire()) {
    s1Valid := false.B
  }
  io.out.valid := s1Valid


  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // use hist as global history!!!
  // Bypass: when the GHR is being written this cycle, consumers see the new
  // value immediately rather than the stale register.
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor
  val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
  // TODO: pass real mask in
  // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
  val maskLatch = Fill(PredictWidth, 1.U(1.W))

  val r = io.redirectInfo.redirect
  // Reconstruct the fetch-packet PC of the redirecting instruction
  // (fetchIdx counts halfwords, hence the shift by 1).
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)
  // BTB
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update.btbType := r.btbType
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))


  // JBTAC: predictor for indirect jumps (jalr targets).
  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update.btbType := r.btbType
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist
  jbtac.io.update.isRVC := r.isRVC

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // calculate global history of each instr
  // Each not-taken branch before slot j contributes one zero shifted into
  // slot j's history; histShift(j) counts those contributors.
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // shift(i)(j) is 1 iff slot i is a not-taken branch located before slot j.
  (0 until PredictWidth).map(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))

  // update ghr
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  // One-hot position of the taken BTB branch / JBTAC indirect hit (0 if none).
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if backend redirects, restore history from backend;
  // if stage3 redirects, restore history from stage3;
  // if stage1 redirects, speculatively update history;
  // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
  newGhr := Mux(io.redirectInfo.flush(),    (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken),
            Mux(io.flush,                   Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
            Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
                                            io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // redirect based on BTB and JBTAC
  // First (lowest-index) redirecting slot wins.
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)

  // io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  // io.s1OutPred.valid := io.out.valid
  io.s1OutPred.valid := io.out.fire()
  // Fresh prediction one cycle after the request (SRAM latency); otherwise
  // replay the latched result so outputs stay stable.
  when (RegNext(io.in.pc.fire())) {
    io.s1OutPred.bits.redirect := btbTaken || jbtacHit
    // Valid mask ends after the taken instruction: one extra halfword slot
    // for a non-RVC (4-byte) taken instruction; full mask on a late jump,
    // whose second half falls into the next fetch packet.
    io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
                                    Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
                                    LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
    // No taken slot: fall through past the whole masked packet.
    io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
    io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
    (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))
    // io.s1OutPred.bits.btbVictimWay := btbWriteWay
    io.s1OutPred.bits.predCtr := btbCtrs
    io.s1OutPred.bits.btbHit := btbValids
    io.s1OutPred.bits.tageMeta := DontCare // TODO: enableBPD
    io.s1OutPred.bits.rasSp := DontCare
    io.s1OutPred.bits.rasTopCtr := DontCare
  }.otherwise {
    io.s1OutPred.bits := s1OutPredLatch
  }

  // Same fresh-vs-latched pattern for the Stage2 payload.
  when (RegNext(io.in.pc.fire())) {
    io.out.bits.pc := pcLatch
    io.out.bits.btb.hits := btbValids.asUInt
    (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
    io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
    io.out.bits.jbtac.target := jbtacTarget
    io.out.bits.tage <> tage.io.out
    // TODO: we don't need this repeatedly!
    io.out.bits.hist := io.s1OutPred.bits.hist
    io.out.bits.btbPred := io.s1OutPred
  }.otherwise {
    io.out.bits := outLatch
  }


  // debug info
  XSDebug("in:(%d %d)   pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3:  s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)

}
236
// Stage2 adds no fields of its own: it only re-times the Stage1 payload,
// so the Stage2->Stage3 interface is identical to Stage1->Stage2.
class Stage2To3IO extends Stage1To2IO {
}
239
// BPU Stage2: a pure one-entry pipeline buffer between Stage1 and Stage3.
// It performs no prediction work of its own — it just holds the Stage1
// payload for one handshake and forwards it unchanged.
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // Remembers that a flush happened since the last accepted input, so an
  // in-flight entry is squashed even after io.flush deasserts.
  val s2Flushed = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)

  // One-deep buffer holding the latest accepted Stage1 payload.
  val bufferedIn = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { bufferedIn := io.in.bits }

  // Occupancy flag. Priority: flush clears, then input fire sets,
  // then output fire clears; otherwise it holds its value.
  val occupied = RegInit(false.B)
  occupied := Mux(io.flush, false.B,
              Mux(io.in.fire(), true.B,
              Mux(io.out.fire(), false.B, occupied)))

  io.out.valid := !io.flush && !s2Flushed && occupied
  // Ready when empty, or when the current entry drains this cycle.
  io.in.ready := !occupied || io.out.fire()

  // Pass the buffered payload straight through.
  io.out.bits := bufferedIn

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", occupied, bufferedIn.pc)
  XSDebug(io.flush, "flush!!!\n")
}
273
// BPU Stage3: refines the Stage1 quick prediction using predecode
// information from the icache, maintains a speculative return-address
// stack (RAS), and flushes the earlier stages when its (more accurate)
// prediction disagrees with Stage1's.
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = Decoupled(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  val predecodeLatch = RegInit(0.U.asTypeOf(io.predecode.bits))
  val predecodeValidLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  // Occupancy: flush > input fire > output fire (priority by when-chain order).
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  // Predecode may arrive in a different cycle than the Stage2 payload;
  // latch it until the output handshake consumes it.
  when (io.predecode.valid) { predecodeLatch := io.predecode.bits }
  when (io.flush || io.out.fire()) {
    predecodeValidLatch := false.B
  }.elsewhen (io.predecode.valid) {
    predecodeValidLatch := true.B
  }

  // Use the live predecode if it arrives this cycle, else the latched copy.
  val predecodeValid = io.predecode.valid || predecodeValidLatch
  val predecode = Mux(io.predecode.valid, io.predecode.bits, predecodeLatch)
  // Output only when both the payload and its predecode are available.
  io.out.valid := validLatch && predecodeValid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.fire()

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & predecode.mask
  val brTakenIdx = if(EnableBPD) {
    // With TAGE: each TAGE taken bit covers two halfword slots (Fill(2, ...)).
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt), PredictWidth)
  } else {
    // Without TAGE: use the BTB 2-bit counter MSB as the taken bit.
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt), PredictWidth)
  }
  // TODO: btb doesn't need to hit, jalIdx/callIdx can be calculated based on instructions read in Cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & predecode.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  // First control-flow-changing instruction in the packet (one-hot, 0 if none).
  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if(EnableBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))

  // Late jump: a 4-byte jump whose first half sits in the packet's last
  // valid slot, so its second half spills into the next fetch packet.
  val lateJump = jmpIdx === HighestBit(predecode.mask, PredictWidth) && !predecode.isRVC(OHToUInt(jmpIdx))

  // Target: fall-through if nothing taken; RAS top for ret; JBTAC for jalr;
  // otherwise the BTB target of the taken slot.
  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(predecode.mask) << 1.U),
                        Mux(jmpIdx === retIdx, rasTopAddr,
                        Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
                        PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call's target can be calculated here

  // Valid mask mirrors Stage1's rule: include the taken instruction (two
  // slots if non-RVC); full mask when nothing taken or on a late jump.
  io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, predecode.mask,
                            Mux(!predecode.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
                            LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.lateJump := lateJump
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  //io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R,
  //  Mux(jmpIdx === jalrIdx, BTBtype.I,
  //  Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
  val firstHist = inLatch.btbPred.bits.hist(0)
  // there may be several notTaken branches before the first jump instruction,
  // so we need to calculate how many zeroes should each instruction shift in its global history.
  // each history is exclusive of instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // shift(i)(j) is 1 iff slot i is a not-taken branch located before slot j.
  (0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush BPU and redirect when target differs from the target predicted in Stage1
  // io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  // Disagreement = taken/not-taken mismatch, or both taken but to different targets.
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.fire()

  // speculative update RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  // Return address = call PC + instruction size (2 for RVC, 4 otherwise).
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, predecode.isRVC), 2.U, 4.U)
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  // Repeated call to the same return address bumps the top counter instead
  // of allocating a new entry (recursion compression).
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.fire() && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      when (rasTop.ctr === 1.U) {
        // Last nested return at this level: pop (saturating at 0).
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS
  // Backend flush restores the checkpointed sp/ctr taken at prediction time.
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back global history in S1 if S3 redirects
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.fire(), "out:(%d %d) pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, io.out.ready, inLatch.pc, io.out.bits.redirect, predecode.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecodeValid=%d\n", validLatch, predecodeValid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's TEMP Perf Cnt
  BoringUtils.addSource(io.out.fire(), "MbpS3Cnt")
  BoringUtils.addSource(io.out.fire() && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect
              && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}
439
// Top-level BPU: wires the three prediction stages together, exposes the
// quick Stage1 prediction (btbOut) and the refined Stage3 prediction
// (tageOut), and gathers BoringUtils-based performance counters.
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush pipeline if misPred and update bpu based on redirect signals from brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    // Quick prediction from Stage1 (goes to the IFU early).
    val btbOut = ValidIO(new BranchPrediction)
    // Refined prediction from Stage3.
    val tageOut = Decoupled(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // Stage1/Stage2 flush on either a Stage3 disagreement or a backend redirect;
  // Stage3 itself flushes only on a backend redirect.
  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temp and ugly code, when perf counters is added( may after adding CSR), please mv the below counter
  // (name, padding-for-aligned-printout) pairs; names must match the
  // BoringUtils.addSource tags emitted elsewhere in the design.
  val bpuPerfCntList = List(
    ("MbpInstr","         "),
    ("MbpRight","         "),
    ("MbpWrong","         "),
    ("MbpBRight","        "),
    ("MbpBWrong","        "),
    ("MbpJRight","        "),
    ("MbpJWrong","        "),
    ("MbpIRight","        "),
    ("MbpIWrong","        "),
    ("MbpRRight","        "),
    ("MbpRWrong","        "),
    ("MbpS3Cnt","         "),
    ("MbpS3TageRed","     "),
    ("MbpS3TageRedDir","  "),
    ("MbpS3TageRedTar","  ")
  )

  // One counter per tag, incremented whenever its (bored-in) condition fires.
  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  (bpuPerfCnts zip bpuPerfCntConds) map { case (cnt, cond) => { when (cond) { cnt := cnt + 1.U }}}

  for(i <- bpuPerfCntList.indices) {
    BoringUtils.addSink(bpuPerfCntConds(i), bpuPerfCntList(i)._1)
  }

  // Dump all counters once when the simulation traps (XSTrap).
  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // if (!p.FPGAPlatform) {
    when (xsTrap) {
      printf("=================BPU's PerfCnt================\n")
      for(i <- bpuPerfCntList.indices) {
        printf(bpuPerfCntList(i)._1 + bpuPerfCntList(i)._2 + " <- " + "%d\n", bpuPerfCnts(i))
      }
    }
  // }
}
515}