package xiangshan.frontend

import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
import chisel3.util.experimental.BoringUtils
import xiangshan.backend.decode.XSTrap

class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}

class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  val btbPred = ValidIO(new BranchPrediction)
}

class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to IFU, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  io.in.pc.ready := true.B

  // flush Stage1 when io.flush
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)

  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // use hist (not ghr) as the global history below: it bypasses this cycle's update
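  // Reading through the bypass keeps the same-cycle TAGE and JBTAC lookups
  // below consistent with a redirect that rewrites the history in this very
  // cycle; indexing with the raw ghr register would use a stale history.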
  val hist = Mux(updateGhr, newGhr, ghr)

  // TAGE predictor
  val tage = if (EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  io.out.bits.tage <> tage.io.out
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for the one-cycle SRAM read latency
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
  // TODO: pass the real mask in
  // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
  val maskLatch = Fill(PredictWidth, 1.U(1.W))

  val r = io.redirectInfo.redirect
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)
  // BTB
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass the real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update.btbType := r.btbType
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))

  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass the real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update.btbType := r.btbType
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist
  jbtac.io.update.isRVC := r.isRVC

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // calculate the global history of each instruction
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  (0 until PredictWidth).foreach(i =>
    shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).foreach(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))
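  // Worked example (PredictWidth = 4): if btbNotTakens = b0011 (slots 0 and
  // 1 hold not-taken branches), then histShift = [0, 1, 2, 2]. Slot j's
  // history is firstHist shifted left by the number of not-taken branches
  // in slots 0..j-1 (zeroes shifted in), so no slot counts its own outcome.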

  // update ghr
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire()) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if the backend redirects, restore the history from the backend;
  // if Stage3 redirects, restore the history from Stage3;
  // if Stage1 redirects, speculatively update the history;
  // if none of the above happens, check whether Stage1 has not-taken branches and shift zeroes in accordingly
  newGhr := Mux(io.redirectInfo.flush(), (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken),
    Mux(io.flush, Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
    Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
    io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // redirect based on BTB and JBTAC
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)
  io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  io.s1OutPred.valid := io.out.valid
  io.s1OutPred.bits.redirect := btbTaken || jbtacHit
  io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
    Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
    LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
  io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U),
    Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
  io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
  // io.s1OutPred.bits.btbVictimWay := btbWriteWay
  io.s1OutPred.bits.predCtr := btbCtrs
  io.s1OutPred.bits.btbHit := btbValids
  io.s1OutPred.bits.tageMeta := DontCare // TODO: EnableBPD
  io.s1OutPred.bits.rasSp := DontCare
  io.s1OutPred.bits.rasTopCtr := DontCare

  io.out.bits.pc := pcLatch
  io.out.bits.btb.hits := btbValids.asUInt
  (0 until PredictWidth).foreach(i => io.out.bits.btb.targets(i) := btbTargets(i))
  io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
  io.out.bits.jbtac.target := jbtacTarget
  // TODO: we don't need to pass the history repeatedly!
  io.out.bits.hist := io.s1OutPred.bits.hist
  io.out.bits.btbPred := io.s1OutPred
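  // instrValid example (PredictWidth = 8): a taken 4-byte RVI branch in
  // slot 2 gives takenIdx = b00000100, so LowerMask(takenIdx << 1.U, 8) =
  // b00001111 keeps both halves of the branch valid; an RVC branch in the
  // same slot keeps only b00000111.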

  // debug info
  XSDebug("in:(%d %d) pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3: s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)
}

class Stage2To3IO extends Stage1To2IO {
}

class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or the backend redirects
  val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { inLatch := io.in.bits }
  val validLatch = RegInit(false.B)
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  io.out.valid := !io.flush && !flushS2 && validLatch
  io.in.ready := !validLatch || io.out.fire()

  // do nothing but pass the latched data through
  io.out.bits := inLatch

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc)
  XSDebug(io.flush, "flush!!!\n")
}

class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = ValidIO(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore the ghr in Stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.valid) {
    validLatch := false.B
  }
  io.out.valid := validLatch && io.predecode.valid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.valid

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // nesting depth of calls sharing this return address
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr
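  // The ctr field compresses recursion: a call whose return address equals
  // the current top increments ctr instead of allocating a new entry, and a
  // ret decrements ctr, only popping once it reaches 1 (see the speculative
  // update logic below).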

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx are one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & io.predecode.bits.mask
  val brTakenIdx = if (EnableBPD) {
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map { t => Fill(2, t.asUInt) }).asUInt), PredictWidth)
  } else {
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map { c => c(1) }).asUInt), PredictWidth)
  }
  // TODO: the BTB doesn't need to hit; jalIdx/callIdx can be calculated from the instructions read from the cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & io.predecode.bits.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(io.predecode.bits.mask & Reverse(Cat(io.predecode.bits.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if (EnableBPD) ~Reverse(Cat(inLatch.tage.takens.map { t => Fill(2, t.asUInt) }).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map { c => c(1) }).asUInt))

  val lateJump = jmpIdx === HighestBit(io.predecode.bits.mask, PredictWidth) && !io.predecode.bits.isRVC(OHToUInt(jmpIdx))
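  // lateJump: the jump is an RVI instruction whose first half occupies the
  // last valid slot of this fetch line, so its second half (and thus the
  // redirect target) spills into the next packet; in that case the whole
  // mask stays valid below instead of being cut off at the jump.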

  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(io.predecode.bits.mask) << 1.U),
    Mux(jmpIdx === retIdx, rasTopAddr,
    Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
    PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call targets can be calculated here

  io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, io.predecode.bits.mask,
    Mux(!io.predecode.bits.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
    LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.lateJump := lateJump
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  // io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R,
  //   Mux(jmpIdx === jalrIdx, BTBtype.I,
  //   Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
  val firstHist = inLatch.btbPred.bits.hist(0)
  // There may be several not-taken branches before the first jump instruction,
  // so we need to calculate how many zeroes each instruction should shift into its global history.
  // Each history excludes the instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  (0 until PredictWidth).foreach(i =>
    shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).foreach(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save RAS checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush the BPU and redirect when the direction or the target differs from Stage1's prediction
  // io.out.bits.redirect := (if (EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.valid

  // speculatively update the RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, io.predecode.bits.isRVC), 2.U, 4.U)
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.valid && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use the checkpoint to recover the RAS
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back the global history in Stage1 if Stage3 redirects
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.valid, "out:%d pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, inLatch.pc, io.out.bits.redirect, io.predecode.bits.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecode.valid=%d\n", validLatch, io.predecode.valid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's temporary perf counters
  BoringUtils.addSource(io.out.valid, "MbpS3Cnt")
  BoringUtils.addSource(io.out.valid && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.valid && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.valid && (inLatch.btbPred.bits.redirect
    && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}
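
// Top-level BPU: Stage1 provides the quick BTB/JBTAC-based prediction to
// the IFU (btbOut), Stage2 simply buffers it for one cycle, and Stage3
// refines it with predecode info (and TAGE when EnableBPD is set),
// flushing the earlier stages whenever it disagrees (tageOut).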
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush the pipeline on misprediction and update the BPU based on redirect signals from the brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = ValidIO(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after the predecode unit is implemented
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temporary and ugly code; once a proper perf counter framework lands (maybe after adding CSR support), move the counters below into it
  val bpuPerfCntList = List(
    ("MbpInstr", " "),
    ("MbpRight", " "),
    ("MbpWrong", " "),
    ("MbpBRight", " "),
    ("MbpBWrong", " "),
    ("MbpJRight", " "),
    ("MbpJWrong", " "),
    ("MbpIRight", " "),
    ("MbpIWrong", " "),
    ("MbpRRight", " "),
    ("MbpRWrong", " "),
    ("MbpS3Cnt", " "),
    ("MbpS3TageRed", " "),
    ("MbpS3TageRedDir", " "),
    ("MbpS3TageRedTar", " ")
  )

  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  (bpuPerfCnts zip bpuPerfCntConds) foreach { case (cnt, cond) => when (cond) { cnt := cnt + 1.U } }

  for (i <- bpuPerfCntList.indices) {
    BoringUtils.addSink(bpuPerfCntConds(i), bpuPerfCntList(i)._1)
  }

  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // if (!p.FPGAPlatform) {
  when (xsTrap) {
    printf("=================BPU's PerfCnt================\n")
    for (i <- bpuPerfCntList.indices) {
      printf(bpuPerfCntList(i)._1 + bpuPerfCntList(i)._2 + " <- " + "%d\n", bpuPerfCnts(i))
    }
  }
  // }
}
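// Note: the perf counters are collected through BoringUtils name matching
// rather than explicit IO. The MbpS3* addSource calls in BPUStage3 are
// picked up by the addSink calls above via their names in bpuPerfCntList;
// the remaining Mbp* conditions are expected to be driven by addSource
// calls elsewhere (likely the brq, which supplies redirectInfo).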