package xiangshan.frontend

import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
import chisel3.util.experimental.BoringUtils
import xiangshan.backend.decode.XSTrap

class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  def tagBits = VAddrBits - idxBits - 1

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}

class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    val hits = Output(UInt(PredictWidth.W))
    val targets = Output(Vec(PredictWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    val hitIdx = Output(UInt(PredictWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  val hist = Output(Vec(PredictWidth, UInt(HistoryLength.W)))
  val btbPred = ValidIO(new BranchPrediction)
}

class BPUStage1 extends XSModule {
  val io = IO(new Bundle() {
    val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // from Stage3
    val flush = Input(Bool())
    val s3RollBackHist = Input(UInt(HistoryLength.W))
    val s3Taken = Input(Bool())
    // to IFU, quick prediction result
    val s1OutPred = ValidIO(new BranchPrediction)
    // to Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  io.in.pc.ready := true.B

  // flush Stage1 when io.flush
  val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)
  val s1OutPredLatch = RegEnable(io.s1OutPred.bits, RegNext(io.in.pc.fire()))
  val outLatch = RegEnable(io.out.bits, RegNext(io.in.pc.fire()))

  val s1Valid = RegInit(false.B)
  when (io.flush) {
    s1Valid := false.B
  }.elsewhen (io.in.pc.fire()) {
    s1Valid := true.B
  }.elsewhen (io.out.fire()) {
    s1Valid := false.B
  }
  io.out.valid := s1Valid

  // global history register
  val ghr = RegInit(0.U(HistoryLength.W))
  // modify updateGhr and newGhr when updating ghr
  val updateGhr = WireInit(false.B)
  val newGhr = WireInit(0.U(HistoryLength.W))
  when (updateGhr) { ghr := newGhr }
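  // Worked example of the speculative history update (illustrative only,
  // assuming HistoryLength = 8 for readability): with ghr = 0b0000_0101, a
  // branch predicted taken this cycle yields newGhr = (ghr << 1) | 1 =
  // 0b0000_1011, while a not-taken branch yields newGhr = ghr << 1 =
  // 0b0000_1010. Since hist below bypasses newGhr whenever updateGhr is
  // asserted, the predictors see the freshest speculative history in the
  // same cycle the update is decided, without waiting for the register write.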
  // use hist as global history!!!
  val hist = Mux(updateGhr, newGhr, ghr)

  // Tage predictor
  val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
  tage.io.req.valid := io.in.pc.fire()
  tage.io.req.bits.pc := io.in.pc.bits
  tage.io.req.bits.hist := hist
  tage.io.redirectInfo <> io.redirectInfo
  // io.s1OutPred.bits.tageMeta := tage.io.meta

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
  // TODO: pass real mask in
  // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
  val maskLatch = Fill(PredictWidth, 1.U(1.W))

  val r = io.redirectInfo.redirect
  val updateFetchpc = r.pc - (r.fetchIdx << 1.U)

  // BTB
  val btb = Module(new BTB)
  btb.io.in.pc <> io.in.pc
  btb.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  btb.io.redirectValid := io.redirectInfo.valid
  btb.io.flush := io.flush

  // btb.io.update.fetchPC := updateFetchpc
  // btb.io.update.fetchIdx := r.fetchIdx
  btb.io.update.pc := r.pc
  btb.io.update.hit := r.btbHit
  btb.io.update.misPred := io.redirectInfo.misPred
  // btb.io.update.writeWay := r.btbVictimWay
  btb.io.update.oldCtr := r.btbPredCtr
  btb.io.update.taken := r.taken
  btb.io.update.target := r.brTarget
  btb.io.update.btbType := r.btbType
  // TODO: add RVC logic
  btb.io.update.isRVC := r.isRVC

  // val btbHit = btb.io.out.hit
  val btbTaken = btb.io.out.taken
  val btbTakenIdx = btb.io.out.takenIdx
  val btbTakenTarget = btb.io.out.target
  // val btbWriteWay = btb.io.out.writeWay
  val btbNotTakens = btb.io.out.notTakens
  val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
  val btbValids = btb.io.out.hits
  val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
  val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType))
  val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))

  val jbtac = Module(new JBTAC)
  jbtac.io.in.pc <> io.in.pc
  jbtac.io.in.pcLatch := pcLatch
  // TODO: pass real mask in
  jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
  jbtac.io.in.hist := hist
  jbtac.io.redirectValid := io.redirectInfo.valid
  jbtac.io.flush := io.flush

  jbtac.io.update.fetchPC := updateFetchpc
  jbtac.io.update.fetchIdx := r.fetchIdx
  jbtac.io.update.misPred := io.redirectInfo.misPred
  jbtac.io.update.btbType := r.btbType
  jbtac.io.update.target := r.target
  jbtac.io.update.hist := r.hist
  jbtac.io.update.isRVC := r.isRVC

  val jbtacHit = jbtac.io.out.hit
  val jbtacTarget = jbtac.io.out.target
  val jbtacHitIdx = jbtac.io.out.hitIdx

  // calculate global history of each instr
  val firstHist = RegNext(hist)
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  (0 until PredictWidth).map(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))
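  // Worked example of the history-shift computation above (illustrative,
  // assuming PredictWidth = 4): if only slot 1 holds a not-taken branch,
  // btbNotTakens = 0b0010, so shift(1) = ~LowerMask(0b0010, 4) = 0b1100 and
  // every other shift row is 0. Summing column j over all rows gives
  // histShift = (0, 0, 1, 1): slots 2 and 3 see firstHist extended by one
  // not-taken (zero) bit, while slots 0 and 1 see firstHist unchanged,
  // because each slot's history must exclude its own outcome and those of
  // later slots.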
  // update ghr
  updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire) && (btbNotTakens.asUInt & maskLatch).orR.asBool
  val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
  val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
  // if backend redirects, restore history from backend;
  // if stage3 redirects, restore history from stage3;
  // if stage1 redirects, speculatively update history;
  // if none of the above happens, check whether stage1 has not-taken branches and shift in zeroes accordingly
  newGhr := Mux(io.redirectInfo.flush(), (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken),
    Mux(io.flush, Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
    Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
    io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))

  // redirect based on BTB and JBTAC
  val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)

  // io.out.valid := RegNext(io.in.pc.fire()) && !io.flush

  // io.s1OutPred.valid := io.out.valid
  io.s1OutPred.valid := io.out.fire()
  when (RegNext(io.in.pc.fire())) {
    io.s1OutPred.bits.redirect := btbTaken || jbtacHit
    io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
      Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
      LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
    io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
    io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
    (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))
    // io.s1OutPred.bits.btbVictimWay := btbWriteWay
    io.s1OutPred.bits.predCtr := btbCtrs
    io.s1OutPred.bits.btbHit := btbValids
    io.s1OutPred.bits.tageMeta := DontCare // TODO: EnableBPD
    io.s1OutPred.bits.rasSp := DontCare
    io.s1OutPred.bits.rasTopCtr := DontCare
  }.otherwise {
    io.s1OutPred.bits := s1OutPredLatch
  }
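  // Worked example of the instrValid mask above (illustrative, assuming
  // PredictWidth = 4): with takenIdx = 0b0100, a 4-byte RVI instruction in
  // slot 2 gives LowerMask(takenIdx << 1, 4) = LowerMask(0b1000, 4) = 0b1111,
  // keeping both 2-byte halves; a 2-byte RVC instruction gives
  // LowerMask(0b0100, 4) = 0b0111. If the taken RVI instruction's upper half
  // spills into the next fetch packet (lateJump), the full maskLatch is kept
  // and the redirect is resolved one packet later.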
  when (RegNext(io.in.pc.fire())) {
    io.out.bits.pc := pcLatch
    io.out.bits.btb.hits := btbValids.asUInt
    (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
    io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U)
    io.out.bits.jbtac.target := jbtacTarget
    io.out.bits.tage <> tage.io.out
    // TODO: we don't need to compute this repeatedly!
    io.out.bits.hist := io.s1OutPred.bits.hist
    io.out.bits.btbPred := io.s1OutPred
  }.otherwise {
    io.out.bits := outLatch
  }

  // debug info
  XSDebug("in:(%d %d) pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
  XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
    io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
  XSDebug(io.flush && io.redirectInfo.flush(),
    "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
    r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException)
  XSDebug(io.flush && !io.redirectInfo.flush(),
    "flush from Stage3: s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)
}

class Stage2To3IO extends Stage1To2IO {
}

class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    // flush from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // flush Stage2 when Stage3 or the backend redirects
  val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  when (io.in.fire()) { inLatch := io.in.bits }
  val validLatch = RegInit(false.B)
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  io.out.valid := !io.flush && !flushS2 && validLatch
  io.in.ready := !validLatch || io.out.fire()

  // do nothing
  io.out.bits := inLatch

  // debug info
  XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
    io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
  XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc)
  XSDebug(io.flush, "flush!!!\n")
}
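// Note on the handshake pattern used by Stage2 and Stage3 (a sketch of the
// intended timing, not normative): validLatch implements a one-entry buffer.
// io.in.ready := !validLatch || io.out.fire() accepts a new packet either
// when the buffer is empty or when the buffered packet leaves in the same
// cycle, so an uninterrupted stream flows at one fetch packet per cycle; if
// the buffer is occupied and io.out does not fire, io.in.ready drops and the
// stall back-pressures the earlier stages.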
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val out = Decoupled(new BranchPrediction)
    // from icache
    val predecode = Flipped(ValidIO(new Predecode))
    // from backend
    val redirectInfo = Input(new RedirectInfo)
    // to Stage1 and Stage2
    val flushBPU = Output(Bool())
    // to Stage1, restore ghr in stage1 when flushBPU is valid
    val s1RollBackHist = Output(UInt(HistoryLength.W))
    val s3Taken = Output(Bool())
  })

  val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
  val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
  val validLatch = RegInit(false.B)
  val predecodeLatch = RegInit(0.U.asTypeOf(io.predecode.bits))
  val predecodeValidLatch = RegInit(false.B)
  when (io.in.fire()) { inLatch := io.in.bits }
  when (io.flush) {
    validLatch := false.B
  }.elsewhen (io.in.fire()) {
    validLatch := true.B
  }.elsewhen (io.out.fire()) {
    validLatch := false.B
  }

  when (io.predecode.valid) { predecodeLatch := io.predecode.bits }
  when (io.flush || io.out.fire()) {
    predecodeValidLatch := false.B
  }.elsewhen (io.predecode.valid) {
    predecodeValidLatch := true.B
  }

  val predecodeValid = io.predecode.valid || predecodeValidLatch
  val predecode = Mux(io.predecode.valid, io.predecode.bits, predecodeLatch)
  io.out.valid := validLatch && predecodeValid && !flushS3 && !io.flush
  io.in.ready := !validLatch || io.out.fire()

  // RAS
  // TODO: split retAddr and ctr
  def rasEntry() = new Bundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // depth of nested calls sharing this return address
  }
  val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
  val sp = Counter(RasSize)
  val rasTop = ras(sp.value)
  val rasTopAddr = rasTop.retAddr

  // get the first taken branch/jal/call/jalr/ret in a fetch line
  // brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx are one-hot encoded.
  // brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
  val brIdx = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & predecode.mask
  val brTakenIdx = if(EnableBPD) {
    LowestBit(brIdx & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt), PredictWidth)
  } else {
    LowestBit(brIdx & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt), PredictWidth)
  }
  // TODO: the BTB doesn't need to hit; jalIdx/callIdx can be calculated from the instructions read from the cache
  val jalIdx = LowestBit(inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & predecode.mask, PredictWidth)
  val callIdx = LowestBit(inLatch.btb.hits & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.call }).asUInt), PredictWidth)
  val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt), PredictWidth)
  val retIdx = LowestBit(predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt), PredictWidth)

  val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, PredictWidth)
  val brNotTakenIdx = brIdx & LowerMask(jmpIdx, PredictWidth) & (
    if(EnableBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
    else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))

  val lateJump = jmpIdx === HighestBit(predecode.mask, PredictWidth) && !predecode.isRVC(OHToUInt(jmpIdx))

  io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(predecode.mask) << 1.U),
    Mux(jmpIdx === retIdx, rasTopAddr,
    Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
    PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call targets can be calculated here

  io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, predecode.mask,
    Mux(!predecode.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
    LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))

  // io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
  io.out.bits.lateJump := lateJump
  io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
  io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
  io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
  // io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R,
  //   Mux(jmpIdx === jalrIdx, BTBtype.I,
  //   Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
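  // Worked example of the target selection above (illustrative, assuming a
  // PredictWidth = 8 packet): with a not-taken branch in slot 0, a call in
  // slot 2 and a ret in slot 5, brTakenIdx = 0, callIdx = 0b00000100 and
  // retIdx = 0b00100000, so jmpIdx = LowestBit(callIdx | retIdx) = 0b00000100
  // selects the call; its target comes from PriorityMux(jmpIdx,
  // inLatch.btb.targets), and the ret is masked off by instrValid. With
  // jmpIdx = 0 the packet falls through to
  // inLatch.pc + 2 * PopCount(predecode.mask).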
  val firstHist = inLatch.btbPred.bits.hist(0)
  // there may be several not-taken branches before the first jump instruction,
  // so we need to calculate how many zeroes each instruction should shift into its global history.
  // each history excludes the instruction's own jump direction.
  val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
  val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
  (0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
  for (j <- 0 until PredictWidth) {
    var tmp = 0.U
    for (i <- 0 until PredictWidth) {
      tmp = tmp + shift(i)(j)
    }
    histShift(j) := tmp
  }
  (0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))

  // save ras checkpoint info
  io.out.bits.rasSp := sp.value
  io.out.bits.rasTopCtr := rasTop.ctr

  // flush BPU and redirect when the target differs from the target predicted in Stage1
  // io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
  //   inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
  //   else false.B)
  io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
    inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
  io.flushBPU := io.out.bits.redirect && io.out.fire()

  // speculatively update the RAS
  val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
  rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, predecode.isRVC), 2.U, 4.U)
  val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
  rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
  when (io.out.fire() && jmpIdx =/= 0.U) {
    when (jmpIdx === callIdx) {
      ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
      when (allocNewEntry) { sp.value := sp.value + 1.U }
    }.elsewhen (jmpIdx === retIdx) {
      when (rasTop.ctr === 1.U) {
        sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
      }.otherwise {
        ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
      }
    }
  }
  // use checkpoint to recover RAS
  val recoverSp = io.redirectInfo.redirect.rasSp
  val recoverCtr = io.redirectInfo.redirect.rasTopCtr
  when (io.redirectInfo.flush()) {
    sp.value := recoverSp
    ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
  }
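  // Worked example of the counter-compressed RAS (illustrative): recursive
  // calls from the same site produce identical return addresses, so instead
  // of pushing duplicates, a call whose return address equals rasTopAddr just
  // increments rasTop.ctr (three nested recursive calls leave a single entry
  // with ctr = 3), while a ret decrements ctr and only pops the entry
  // (sp - 1) once ctr reaches 1. The (rasSp, rasTopCtr) checkpoint carried
  // with each prediction is enough for the redirect recovery above to undo
  // this speculative state.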
  // roll back global history in S1 if S3 redirects
  io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
  // whether Stage3 has a taken jump
  io.s3Taken := jmpIdx.orR.asBool

  // debug info
  XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
  XSDebug(io.out.fire(), "out:(%d %d) pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
    io.out.valid, io.out.ready, inLatch.pc, io.out.bits.redirect, predecode.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
  XSDebug("flushS3=%d\n", flushS3)
  XSDebug("validLatch=%d predecodeValid=%d\n", validLatch, predecodeValid)
  XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
    brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)

  // BPU's temporary perf counters
  BoringUtils.addSource(io.out.fire(), "MbpS3Cnt")
  BoringUtils.addSource(io.out.fire() && io.out.bits.redirect, "MbpS3TageRed")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool), "MbpS3TageRedDir")
  BoringUtils.addSource(io.out.fire() && (inLatch.btbPred.bits.redirect
    && jmpIdx.orR.asBool && (io.out.bits.target =/= inLatch.btbPred.bits.target)), "MbpS3TageRedTar")
}

class BPU extends XSModule {
  val io = IO(new Bundle() {
    // from backend
    // flush the pipeline on misprediction and update the BPU with redirect signals from the brq
    val redirectInfo = Input(new RedirectInfo)

    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }

    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = Decoupled(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after the predecode unit is implemented
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  s1.io.redirectInfo <> io.redirectInfo
  s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits <> io.in.pc.bits
  io.btbOut <> s1.io.s1OutPred
  s1.io.s3RollBackHist := s3.io.s1RollBackHist
  s1.io.s3Taken := s3.io.s3Taken

  s1.io.out <> s2.io.in
  s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()

  s2.io.out <> s3.io.in
  s3.io.flush := io.redirectInfo.flush()
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out
  s3.io.redirectInfo <> io.redirectInfo

  // TODO: temporary and ugly code; when real perf counters are added (maybe after adding CSR), move the counters below
  val bpuPerfCntList = List(
    ("MbpInstr"," "),
    ("MbpRight"," "),
    ("MbpWrong"," "),
    ("MbpBRight"," "),
    ("MbpBWrong"," "),
    ("MbpJRight"," "),
    ("MbpJWrong"," "),
    ("MbpIRight"," "),
    ("MbpIWrong"," "),
    ("MbpRRight"," "),
    ("MbpRWrong"," "),
    ("MbpS3Cnt"," "),
    ("MbpS3TageRed"," "),
    ("MbpS3TageRedDir"," "),
    ("MbpS3TageRedTar"," ")
  )

  val bpuPerfCnts = List.fill(bpuPerfCntList.length)(RegInit(0.U(XLEN.W)))
  val bpuPerfCntConds = List.fill(bpuPerfCntList.length)(WireInit(false.B))
  (bpuPerfCnts zip bpuPerfCntConds) map { case (cnt, cond) => { when (cond) { cnt := cnt + 1.U } } }

  for (i <- bpuPerfCntList.indices) {
    BoringUtils.addSink(bpuPerfCntConds(i), bpuPerfCntList(i)._1)
  }

  val xsTrap = WireInit(false.B)
  BoringUtils.addSink(xsTrap, "XSTRAP_BPU")

  // if (!p.FPGAPlatform) {
  when (xsTrap) {
    printf("=================BPU's PerfCnt================\n")
    for (i <- bpuPerfCntList.indices) {
      printf(bpuPerfCntList(i)._1 + bpuPerfCntList(i)._2 + " <- " + "%d\n", bpuPerfCnts(i))
    }
  }
  // }
}
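// How the temporary perf-counter plumbing fits together (a sketch, not part
// of the design): each BoringUtils.addSource call in BPUStage3 publishes a
// Bool under a string key, and the addSink loop in BPU drains it into the
// matching bpuPerfCntConds wire, so e.g. "MbpS3Cnt" counts Stage3 output
// beats. Adding a new counter only takes a source plus an entry in
// bpuPerfCntList; the name "MbpExample" below is hypothetical:
//
//   BoringUtils.addSource(someCond, "MbpExample")  // in any module
//   ("MbpExample"," ")                             // appended to bpuPerfCntList
//
// The sink and the increment logic are generated by the existing loops.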