package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import utils._

/** Splits a virtual address into the fields used to index a banked predictor table.
  *
  * Field layout, most significant first (Chisel packs the first-declared Bundle field
  * into the most significant bits): `tag | idx | bank | offset`, with a 2-bit `offset`.
  *
  * @param idxBits  width of the set-index field
  * @param wayBanks number of banks one line is spread across; the bank field is
  *                 `log2Up(wayBanks)` bits wide
  */
class TableAddr(val idxBits: Int, val wayBanks: Int) extends XSBundle {
  def wayBankBits: Int = log2Up(wayBanks)
  // Whatever is left of the virtual address after idx, bank and the 2-bit offset.
  def tagBits: Int = VAddrBits - idxBits - wayBankBits - 2

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val bank = UInt(wayBankBits.W)
  val offset = UInt(2.W)

  /** Reinterprets `x`, viewed as a VAddrBits-wide UInt, as a TableAddr. */
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getIdx(x: UInt): UInt = fromUInt(x).idx
  def getWayBank(x: UInt): UInt = fromUInt(x).bank
  def getTag(x: UInt): UInt = fromUInt(x).tag
  /** Position of `x` within its line: the bank field concatenated with the offset field. */
  def getLineOffset(x: UInt): UInt = {
    val addr = fromUInt(x)
    Cat(addr.bank, addr.offset)
  }
}

/** Branch prediction unit: from a fetch pc, produces a redirect one cycle later.
  *
  * Prediction sources:
  *  - BTB: tags for all BtbSets entries held in registers and matched in parallel;
  *    the matching entry's index selects a set in per-bank SRAMs holding type/target.
  *  - PHT: per-bank Mems of 2-bit entries, indexed like the BTB; bit 1 means taken.
  *  - RAS: supplies the target for BTBtype.R entries.
  *  - JBTAC: banked SRAM table supplying the target for BTBtype.I entries.
  *
  * NOTE(review): nothing in this module writes meta, btb, pht, ras or jbtac, and the
  * SRAM write ports are not driven here; the update path appears to be still TODO.
  */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }
    val out = new Bundle { val redirect = Valid(UInt(VAddrBits.W)) }
  })

  // Started by io.flush, stopped by the next valid pc (assumed BoolStopWatch(start, stop)
  // semantics -- confirm in utils); while high, BTB hits are suppressed.
  val flush = BoolStopWatch(io.flush, io.in.pc.valid, startHighPriority = true)

  // ------------------------------------------------------------------------- BTB
  val btbAddr = new TableAddr(log2Up(BtbSets), BtbWayBanks)
  def btbMeta() = new Bundle {
    val valid = Bool()
    val tag = UInt(btbAddr.tagBits.W)
  }
  def btbEntry() = new Bundle {
    val _type = UInt(2.W)
    val target = UInt(VAddrBits.W)
  }

  val meta = RegInit(0.U.asTypeOf(Vec(BtbSets, btbMeta())))
  val btb = List.fill(BtbWayBanks)(Module(new SRAMTemplate(btbEntry(), set = BtbSets, shouldReset = true, holdRead = true, singlePort = true)))

  // PHT, which has the same complete association structure as BTB's
  val pht = List.fill(BtbWayBanks)(Mem(BtbSets, UInt(2.W)))

  // A fetch package may start mid-line: slots from fetchPkgBank upwards come from the
  // pc's own ("lo") line, the remaining slots wrap into the following ("hi") line.
  val fetchPkgBank = btbAddr.getWayBank(io.in.pc.bits)
  // Package starts at the beginning of a line, so no hi line is needed.
  val fetchPkgAligned = btbAddr.getLineOffset(io.in.pc.bits) === 0.U
  val loPkgTag = btbAddr.getTag(io.in.pc.bits)
  // NOTE(review): TableAddr's tag excludes the idx bits and `meta` is searched without
  // them, so lines differing only in idx look like they alias, and `loPkgTag + 1` is not
  // the tag of the immediately following line. Confirm this is intended.
  val hiPkgTag = loPkgTag + 1.U
  val loMetaHits = VecInit(meta.map(m => m.valid && m.tag === loPkgTag))
  val hiMetaHits = VecInit(meta.map(m => m.valid && m.tag === hiPkgTag))
  val loMetaHit = io.in.pc.valid && loMetaHits.asUInt.orR
  val hiMetaHit = io.in.pc.valid && hiMetaHits.asUInt.orR && !fetchPkgAligned
  val loMetaHitIdx = PriorityEncoder(loMetaHits.asUInt)
  val hiMetaHitIdx = PriorityEncoder(hiMetaHits.asUInt)

  // Banks below fetchPkgBank hold the wrapped-around slots, so they read the hi line.
  for ((sram, i) <- btb.zipWithIndex) {
    val readsHi = i.U < fetchPkgBank
    sram.io.r.req.valid := Mux(readsHi, hiMetaHit, loMetaHit)
    sram.io.r.req.bits.setIdx := Mux(readsHi, hiMetaHitIdx, loMetaHitIdx)
  }

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.valid)
  // Every register feeding the response side is created here, outside any `when`:
  // a RegNext/RegEnable instantiated inside a `when` only updates while that condition
  // holds, which would tie its sampling to the bank-select conditions below.
  val fetchPkgBankLatch = RegEnable(fetchPkgBank, io.in.pc.valid)
  val loMetaHitLatch = RegNext(loMetaHit, init = false.B)
  val hiMetaHitLatch = RegNext(hiMetaHit, init = false.B)
  val btbReqFiredLatch = btb.map(sram => RegNext(sram.io.r.req.fire(), init = false.B))
  // Mem reads are combinational; register them so they line up with the SRAM responses.
  val phtLoLatch = pht.map(p => RegEnable(p.read(loMetaHitIdx), io.in.pc.valid))
  val phtHiLatch = pht.map(p => RegEnable(p.read(hiMetaHitIdx), io.in.pc.valid))

  // Defaults keep every element driven even if no bank-select `when` below fires.
  val btbRead = WireInit(0.U.asTypeOf(Vec(FetchWidth, btbEntry())))
  val btbHits = WireInit(VecInit(Seq.fill(FetchWidth)(false.B)))
  val phtRead = WireInit(0.U.asTypeOf(Vec(FetchWidth, UInt(2.W))))
  for (i <- 0 until FetchWidth) {
    for (j <- 0 until BtbWayBanks) {
      when (j.U === fetchPkgBankLatch) {
        // Resolved at elaboration time: slot i of a package starting at bank j lives in
        // bank (i + j), wrapping into the hi line once it passes the last bank. Indexing
        // the Scala lists through a hardware Mux would evaluate both (out-of-range) indices.
        val isLoPkg = i + j < BtbWayBanks
        val bankIdx = (i + j) % BtbWayBanks
        btbRead(i) := btb(bankIdx).io.r.resp.data(0)
        btbHits(i) := !flush && (if (isLoPkg) loMetaHitLatch else hiMetaHitLatch) && btbReqFiredLatch(bankIdx)
        phtRead(i) := (if (isLoPkg) phtLoLatch(bankIdx) else phtHiLatch(bankIdx))
      }
    }
  }
  val phtTaken = phtRead.map { ctr => ctr(1).asBool }

  // ------------------------------------------------------------------------- RAS
  def rasEntry() = new Bundle {
    val target = UInt(VAddrBits.W)
    val layer = UInt(3.W) // layer of nested function
  }
  val ras = Mem(RasSize, rasEntry())
  val sp = Counter(RasSize)
  val rasRead = ras.read(sp.value)
  val retAddr = RegEnable(rasRead.target, io.in.pc.valid)

  // ------------------------------------------------------------------------- JBTAC
  def jbtacEntry() = new Bundle {
    val valid = Bool()
    val target = UInt(VAddrBits.W)
  }
  val jbtacAddr = new TableAddr(log2Up(JbtacSets), JbtacBanks)
  val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSets, shouldReset = true, holdRead = true, singlePort = true)))
  val jbtacBank = jbtacAddr.getWayBank(io.in.pc.bits)
  // Banks below the pc's bank hold wrapped-around slots, so they read the next set.
  for ((sram, i) <- jbtac.zipWithIndex) {
    sram.io.r.req.valid := io.in.pc.valid
    sram.io.r.req.bits.setIdx := jbtacAddr.getIdx(io.in.pc.bits) + Mux(i.U >= jbtacBank, 0.U, 1.U)
  }
  // The SRAM data arrives one cycle after the request, so rotate by the latched pc's bank.
  val jbtacBankLatch = jbtacAddr.getWayBank(pcLatch)
  val jbtacRead = WireInit(0.U.asTypeOf(Vec(JbtacBanks, jbtacEntry())))
  for (i <- 0 until JbtacBanks) {
    for (j <- 0 until JbtacBanks) {
      when (j.U === jbtacBankLatch) {
        jbtacRead(i) := jbtac((i + j) % JbtacBanks).io.r.resp.data(0)
      }
    }
  }

  // ------------------------------------------------------------------------- redirect
  // Slot i redirects when its BTB entry hits and, for B-type entries, the PHT predicts
  // taken, and, for I-type entries, the JBTAC entry is valid.
  // NOTE(review): jbtacRead is indexed by fetch slot; assumes JbtacBanks >= FetchWidth.
  val redirectIdx = Wire(Vec(FetchWidth, Bool()))
  val redirectTarget = Wire(Vec(FetchWidth, UInt(VAddrBits.W)))
  for (i <- 0 until FetchWidth) {
    val entryType = btbRead(i)._type
    redirectIdx(i) := btbHits(i) &&
      Mux(entryType === BTBtype.B, phtTaken(i), true.B) &&
      Mux(entryType === BTBtype.I, jbtacRead(i).valid, true.B)
    // I-type: JBTAC target; R-type: RAS return address; otherwise the BTB's own target.
    redirectTarget(i) := Mux(entryType === BTBtype.I, jbtacRead(i).target,
      Mux(entryType === BTBtype.R, retAddr, btbRead(i).target))
  }
  // PriorityMux picks the lowest-numbered redirecting slot.
  io.out.redirect.valid := redirectIdx.asUInt.orR
  io.out.redirect.bits := PriorityMux(redirectIdx, redirectTarget)
}