package xiangshan.frontend

import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
import chisel3.util.experimental.BoringUtils
import xiangshan.backend.decode.XSTrap

/**
 * Address view used to index banked predictor tables.
 * A virtual address is split as | tag | idx | offset(1 bit) |;
 * the single offset bit implies 2-byte (halfword) instruction alignment.
 * `idx` is further split into a bank number (low log2(banks) bits)
 * and a per-bank index (remaining high bits).
 */
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  // Reinterpret a raw UInt address as this bundle layout.
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // Bank number = low bits of idx; bank-local index = high bits of idx.
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}

/**
 * Payload forwarded from BPU Stage1 to Stage2: the latched fetch pc plus the
 * raw responses of the sub-predictors (BTB, JBTAC, TAGE), the per-slot global
 * history, and Stage1's quick prediction (`btbPred`) for later comparison.
 */
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  val btbPred = ValidIO(new BranchPrediction)
}

/**
 * BPU Stage1: queries BTB/JBTAC/TAGE with the incoming pc, produces a quick
 * one-cycle prediction for the IFU (`s1OutPred`) and passes the raw
 * sub-predictor results on to Stage2. Also owns the speculative global
 * history register, which can be rolled back by Stage3 or the backend.
 */
class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to ifu, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // Stage1 never back-pressures the pc stream.
  io.in.pc.ready := true.B

  // flush Stage1 when io.flush
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)

  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
  // use hist as global history!!!
  // Bypass: same-cycle updates are visible to this cycle's predictions.
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor (FakeTAGE is a stub when the complex predictor is disabled)
  val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  io.out.bits.tage <> tage.io.out
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
  // TODO: pass real mask in
  // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
  val maskLatch = Fill(PredictWidth, 1.U(1.W))

  val r = io.redirectInfo.redirect
  // Reconstruct the fetch-packet pc from the redirecting instruction's pc and
  // its index within the packet (halfword slots, hence the << 1).
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)
  // BTB
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // BTB update channel, driven from the backend redirect info.
  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update.btbType := r.btbType
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // Unpack BTB response for readability below.
  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))

  // JBTAC: target predictor for indirect jumps, indexed with pc and history.
  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update.btbType := r.btbType
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist
  jbtac.io.update.isRVC := r.isRVC

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // calculate global history of each instr:
  // histShift(j) counts the predicted-not-taken branches strictly before
  // slot j; each slot's history is firstHist shifted by that count, so a
  // slot's history excludes its own direction.
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // shift(i) has ones at all positions above i iff slot i is a not-taken branch.
  (0 until PredictWidth).map(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  // Column-wise popcount of the shift matrix.
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))

  // update ghr
  // NOTE(review): `io.in.pc.fire` here lacks the `()` used everywhere else —
  // same method, but worth normalizing for consistency.
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if backend redirects, restore history from backend;
  // if stage3 redirects, restore history from stage3;
  // if stage1 redirects, speculatively update history;
  // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
  newGhr := Mux(io.redirectInfo.flush(), (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken),
    Mux(io.flush, Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
    Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
    io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // redirect based on BTB and JBTAC: first (lowest) predicted-taken slot wins
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)
  io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  io.s1OutPred.valid := io.out.valid
  io.s1OutPred.bits.redirect := btbTaken || jbtacHit
  // Valid-instruction mask: everything up to and including the taken slot
  // (two halfword slots for a non-RVC instruction), or the whole packet when
  // there is no redirect or the jump straddles the fetch boundary (lateJump).
  io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
    Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
    LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
  // Fall-through target when nothing is taken; BTB target for branches,
  // JBTAC target for indirect jumps.
  io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
  io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
  // io.s1OutPred.bits.btbVictimWay := btbWriteWay
  io.s1OutPred.bits.predCtr := btbCtrs
  io.s1OutPred.bits.btbHit := btbValids
  io.s1OutPred.bits.tageMeta := DontCare // TODO: enableBPD
  io.s1OutPred.bits.rasSp := DontCare
  io.s1OutPred.bits.rasTopCtr := DontCare

  // Raw sub-predictor results forwarded to Stage2.
  io.out.bits.pc := pcLatch
  io.out.bits.btb.hits := btbValids.asUInt
  (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
  io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
  io.out.bits.jbtac.target := jbtacTarget
  // TODO: we don't need this repeatedly!
  io.out.bits.hist := io.s1OutPred.bits.hist
  io.out.bits.btbPred := io.s1OutPred

  // debug info
  XSDebug("in:(%d %d) pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3: s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)

}

// Stage2 -> Stage3 payload is identical to Stage1 -> Stage2 for now.
class Stage2To3IO extends Stage1To2IO {
}

/**
 * BPU Stage2: a pure one-entry pipeline buffer between Stage1 and Stage3.
 * Latches the incoming bundle on fire and forwards it unchanged; flushed by
 * Stage3 or backend redirects.
 */
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or backend redirects
  val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { inLatch := io.in.bits }
  // Occupancy flag: set on accept, cleared on flush or hand-off.
  val validLatch = RegInit(false.B)
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  io.out.valid := !io.flush && !flushS2 && validLatch
  io.in.ready := !validLatch || io.out.fire()

  // do nothing
  io.out.bits := inLatch

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc)
  XSDebug(io.flush, "flush!!!\n")
}

/**
 * BPU Stage3: combines the latched sub-predictor results with predecode
 * information from the icache to compute the final prediction, maintains the
 * return address stack (RAS), and flushes Stage1/Stage2 when its prediction
 * disagrees with Stage1's quick prediction.
 */
class BPUStage3 extends XSModule {
  val io =
  IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = Decoupled(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  // Predecode may arrive in a different cycle than the Stage2 bundle, so it
  // is latched independently with its own valid flag.
  val predecodeLatch = RegInit(0.U.asTypeOf(io.predecode.bits))
  val predecodeValidLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  when (io.predecode.valid) { predecodeLatch := io.predecode.bits }
  when (io.flush || io.out.fire()) {
    predecodeValidLatch := false.B
  }.elsewhen (io.predecode.valid) {
    predecodeValidLatch := true.B
  }

  // Use this cycle's predecode if present, otherwise the latched copy.
  val predecodeValid = io.predecode.valid || predecodeValidLatch
  val predecode = Mux(io.predecode.valid, io.predecode.bits, predecodeLatch)
  io.out.valid := validLatch && predecodeValid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.fire()

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & predecode.mask
  // Taken direction comes from TAGE when enabled, else from the BTB 2-bit
  // counter MSB. Fill(2, ...) widens each FetchWidth-granular TAGE bit to the
  // two halfword slots it covers.
  val brTakenIdx = if(EnableBPD) {
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt), PredictWidth)
  } else {
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt), PredictWidth)
  }
  // TODO: btb doesn't need to hit, jalIdx/callIdx can be calculated based on instructions read in Cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & predecode.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  // First control-flow-changing slot in the packet (one-hot, 0 if none).
  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if(EnableBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))

  // A non-RVC jump whose first half occupies the last valid slot straddles
  // the fetch boundary.
  val lateJump = jmpIdx === HighestBit(predecode.mask, PredictWidth) && !predecode.isRVC(OHToUInt(jmpIdx))

  // Final target: fall-through / RAS top for ret / JBTAC for jalr / BTB otherwise.
  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(predecode.mask) << 1.U),
    Mux(jmpIdx === retIdx, rasTopAddr,
    Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
    PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call's target can be calculated here

  io.out.bits.instrValid := Mux(!jmpIdx.orR ||
    lateJump, predecode.mask,
    // Mask covers up to and including the taken slot (two slots for non-RVC).
    Mux(!predecode.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
    LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.lateJump := lateJump
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  //io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R,
  //  Mux(jmpIdx === jalrIdx, BTBtype.I,
  //  Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
  val firstHist = inLatch.btbPred.bits.hist(0)
  // there may be several notTaken branches before the first jump instruction,
  // so we need to calculate how many zeroes should each instruction shift in its global history.
  // each history is exclusive of instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  // shift(i) has ones above position i iff slot i is a not-taken branch;
  // summing column j counts the not-taken branches strictly before slot j.
  (0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush BPU and redirect when target differs from the target predicted in Stage1:
  // either the taken/not-taken decision changed (XOR term) or both predict
  // taken but the targets differ.
  // io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.fire()

  // speculative update RAS
  // Return address = call pc + 2 (RVC) or + 4; a recursive call to the same
  // return address bumps the counter of the top entry instead of pushing.
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, predecode.isRVC), 2.U, 4.U)
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.fire() && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      // ret: pop the entry when its nesting counter reaches one, otherwise
      // just decrement the counter.
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS on a backend flush
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }

  // roll back global history in S1 if S3 redirects
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.fire(), "out:(%d %d) pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, io.out.ready, inLatch.pc, io.out.bits.redirect, predecode.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecodeValid=%d\n", validLatch, predecodeValid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's TEMP Perf Cnt (wired to the top-level counters via BoringUtils)
  BoringUtils.addSource(io.out.fire(), "MbpS3Cnt")
  BoringUtils.addSource(io.out.fire() && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect
    && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}

/**
 * Top-level BPU: wires the three prediction stages together, fans out flush
 * signals (Stage3/backend -> Stage1/Stage2), exposes Stage1's quick result
 * as `btbOut` and Stage3's final result as `tageOut`, and hosts the
 * temporary BoringUtils-based performance counters.
 */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush pipeline if misPred and update bpu based on redirect signals from brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = Decoupled(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temp and ugly code, when perf counters is added( may after adding CSR), please mv the below counter
  // (name, padding-for-printout) pairs; names match BoringUtils source tags.
  val bpuPerfCntList = List(
    ("MbpInstr"," "),
    ("MbpRight"," "),
    ("MbpWrong"," "),
    ("MbpBRight"," "),
    ("MbpBWrong"," "),
    ("MbpJRight"," "),
    ("MbpJWrong"," "),
    ("MbpIRight"," "),
    ("MbpIWrong"," "),
    ("MbpRRight"," "),
    ("MbpRWrong"," "),
    ("MbpS3Cnt"," "),
    ("MbpS3TageRed"," "),
    ("MbpS3TageRedDir"," "),
    ("MbpS3TageRedTar"," ")
  )

  // One counter + one increment condition per perf event; conditions are
  // driven by BoringUtils sources scattered across the design.
  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  (bpuPerfCnts zip bpuPerfCntConds) map { case (cnt, cond) => { when (cond) { cnt := cnt + 1.U }}}

  for(i <- bpuPerfCntList.indices) {
    BoringUtils.addSink(bpuPerfCntConds(i), bpuPerfCntList(i)._1)
  }

  // Dump all counters when the simulation trap signal fires.
  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // if (!p.FPGAPlatform) {
  when (xsTrap) {
    printf("=================BPU's PerfCnt================\n")
    for(i <- bpuPerfCntList.indices) {
      printf(bpuPerfCntList(i)._1 + bpuPerfCntList(i)._2 + " <- " + "%d\n", bpuPerfCnts(i))
    }
  }
  // }
}