xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision ebd97ecb0965bb584d0fde039539d67d673e7268)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import utils._
7
/** Splits a virtual address into the fields used to address a banked table.
  *
  * The fields are, in declaration order, `tag | idx | bank | offset`; their
  * widths add up to `VAddrBits`. The declaration order must not change:
  * `asTypeOf` relies on it to place each field (per Chisel Bundle packing,
  * the first declared field takes the most significant bits).
  *
  * @param idxBits  width of the set-index field
  * @param wayBanks number of banks; the bank field is `log2Up(wayBanks)` bits wide
  */
class TableAddr(val idxBits: Int, val wayBanks: Int) extends XSBundle {
  /** Width of the bank-select field. */
  def wayBankBits = log2Up(wayBanks)

  /** Width of the tag: whatever is left after index, bank and the 2 offset bits. */
  def tagBits = VAddrBits - idxBits - wayBankBits - 2

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val bank = UInt(wayBankBits.W)
  val offset = UInt(2.W)

  /** Reinterprets `x` (first resized to `VAddrBits`) as this address layout. */
  def fromUInt(x: UInt) = {
    val resized = x.asTypeOf(UInt(VAddrBits.W))
    resized.asTypeOf(this)
  }

  /** Set-index field of `x`. */
  def getIdx(x: UInt) = {
    val decoded = fromUInt(x)
    decoded.idx
  }

  /** Bank-select field of `x`. */
  def getWayBank(x: UInt) = {
    val decoded = fromUInt(x)
    decoded.bank
  }

  /** Tag field of `x`. */
  def getTag(x: UInt) = {
    val decoded = fromUInt(x)
    decoded.tag
  }

  /** Bank and offset fields of `x` concatenated (bank in the upper bits). */
  def getLineOffset(x: UInt) = {
    val decoded = fromUInt(x)
    Cat(decoded.bank, decoded.offset)
  }
}
24
/** Branch prediction unit.
  *
  * For the fetch packet starting at `io.in.pc`, this module looks up the BTB,
  * PHT, RAS and JBTAC and raises `io.out.redirect` with the predicted target of
  * the first fetch slot that is predicted to redirect. The table lookups are
  * registered, so the prediction is based on the PC accepted in the previous
  * valid cycle.
  *
  * NOTE(review): only the read/prediction path is present in this class; the
  * update (write) paths of the BTB, PHT, RAS and JBTAC are not driven here —
  * TODO confirm they are still to be added.
  */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }
    val out = new Bundle { val redirect = Valid(UInt(VAddrBits.W)) }
  })

  // Each fetch slot is served by its own bank, and the lo/hi package scheme
  // below only spans two consecutive packages. Fail early with a clear message
  // instead of an index error deep inside elaboration.
  require(FetchWidth <= BtbWayBanks, "BPU: FetchWidth must not exceed BtbWayBanks")
  require(FetchWidth <= JbtacBanks, "BPU: FetchWidth must not exceed JbtacBanks")

  // Used below to suppress BTB hits.
  // NOTE(review): presumably set by io.flush and cleared by the next valid pc —
  // confirm against BoolStopWatch in utils.
  val flush = BoolStopWatch(io.flush, io.in.pc.valid, startHighPriority = true)

  // BTB
  val btbAddr = new TableAddr(log2Up(BtbSets), BtbWayBanks)
  def btbMeta() = new Bundle {
    val valid = Bool()
    val tag = UInt(btbAddr.tagBits.W)
  }
  def btbEntry() = new Bundle {
    val _type = UInt(2.W)
    val target = UInt(VAddrBits.W)
  }

  // Tags are kept in registers so that every set can be compared in parallel;
  // the entries themselves live in one SRAM per bank.
  val meta = RegInit(0.U.asTypeOf(Vec(BtbSets, btbMeta())))
  val btb = List.fill(BtbWayBanks)(Module(new SRAMTemplate(btbEntry(), set = BtbSets, shouldReset = true, holdRead = true, singlePort = true)))

  // PHT, which has the same complete association structure as BTB's
  val pht = List.fill(BtbWayBanks)(Mem(BtbSets, UInt(2.W)))
  val phtRead = Wire(Vec(FetchWidth, UInt(2.W)))

  val fetchPkgBank = btbAddr.getWayBank(io.in.pc.bits)
  // True when the bank and offset bits of the pc are all zero, i.e. the fetch
  // package does not straddle two tags (original note: 32B aligned).
  val fetchPkgAligned = btbAddr.getLineOffset(io.in.pc.bits) === 0.U
  val loPkgTag = btbAddr.getTag(io.in.pc.bits)
  val hiPkgTag = loPkgTag + 1.U
  val loMetaHits = Wire(Vec(BtbSets, Bool()))
  val hiMetaHits = Wire(Vec(BtbSets, Bool()))
  for (s <- 0 until BtbSets) {
    loMetaHits(s) := meta(s).valid && meta(s).tag === loPkgTag
    hiMetaHits(s) := meta(s).valid && meta(s).tag === hiPkgTag
  }
  val loMetaHit = io.in.pc.valid && loMetaHits.reduce(_||_)
  val hiMetaHit = io.in.pc.valid && hiMetaHits.reduce(_||_) && !fetchPkgAligned
  val loMetaHitIdx = PriorityEncoder(loMetaHits.asUInt)
  val hiMetaHitIdx = PriorityEncoder(hiMetaHits.asUInt)

  // Banks below the starting bank belong to the next (hi) package, the others
  // to the package the pc points into (lo).
  for (b <- 0 until BtbWayBanks) {
    val readsHiPkg = b.U < fetchPkgBank
    btb(b).io.r.req.valid := Mux(readsHiPkg, hiMetaHit, loMetaHit)
    btb(b).io.r.req.bits.setIdx := Mux(readsHiPkg, hiMetaHitIdx, loMetaHitIdx)
  }

  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.valid)

  // Request-side state registered for the response cycle. These registers are
  // created outside of any `when` on purpose: a register built inside a `when`
  // body only updates while that condition holds and would return stale data.
  val fetchPkgBankLatch = RegEnable(fetchPkgBank, io.in.pc.valid)
  val loMetaHitLatch = RegNext(loMetaHit)
  val hiMetaHitLatch = RegNext(hiMetaHit)
  val btbFireLatch = btb.map(bank => RegNext(bank.io.r.req.fire(), init = false.B))
  // One PHT read per bank, using the same set index as that bank's BTB read.
  val phtLatch = pht.zipWithIndex.map { case (bank, b) =>
    RegEnable(bank.read(Mux(b.U < fetchPkgBank, hiMetaHitIdx, loMetaHitIdx)), io.in.pc.valid)
  }

  val btbRead = Wire(Vec(FetchWidth, btbEntry()))
  val btbHits = Wire(Vec(FetchWidth, Bool()))
  for (i <- 0 until FetchWidth) {
    // Defaults keep the wires fully initialized; a slot without a matching
    // bank never reports a hit.
    btbRead(i) := DontCare
    btbHits(i) := false.B
    phtRead(i) := DontCare
    for (j <- 0 until BtbWayBanks) {
      // i and j are elaboration-time constants, so the bank that serves slot i
      // when the packet starts at bank j is resolved here in Scala rather than
      // in hardware (a hardware `i.U + j.U` would be truncated).
      val isLoPkg = i + j < BtbWayBanks
      val bank = (i + j) % BtbWayBanks
      val metaHitLatch = if (isLoPkg) loMetaHitLatch else hiMetaHitLatch
      when (j.U === fetchPkgBankLatch) {
        btbRead(i) := btb(bank).io.r.resp.data(0)
        btbHits(i) := !flush && metaHitLatch && btbFireLatch(bank)
        phtRead(i) := phtLatch(bank)
      }
    }
  }
  // The upper counter bit is the taken/not-taken prediction.
  val phtTaken = phtRead.map { ctr => ctr(1).asBool }

  // RAS
  def rasEntry() = new Bundle {
    val target = UInt(VAddrBits.W)
    val layer = UInt(3.W) // layer of nested function
  }
  val ras = Mem(RasSize, rasEntry())
  val sp = Counter(RasSize)
  val rasRead = ras.read(sp.value)
  val retAddr = RegEnable(rasRead.target, io.in.pc.valid)

  // JBTAC
  def jbtacEntry() = new Bundle {
    val valid = Bool()
    val target = UInt(VAddrBits.W)
  }
  val jbtacAddr = new TableAddr(log2Up(JbtacSets), JbtacBanks)
  // Hardware modules must be instantiated through Module(...), as done for the BTB banks.
  val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSets, shouldReset = true, holdRead = true, singlePort = true)))
  val jbtacReqBank = jbtacAddr.getWayBank(io.in.pc.bits)
  for (b <- 0 until JbtacBanks) {
    jbtac(b).io.r.req.valid := io.in.pc.valid
    // Banks below the starting bank read the following set.
    jbtac(b).io.r.req.bits.setIdx := jbtacAddr.getIdx(io.in.pc.bits) + Mux(b.U >= jbtacReqBank, 0.U, 1.U)
  }
  // The response arrives a cycle after the request, so the bank rotation is
  // taken from the latched pc, matching the BTB path.
  val jbtacRespBank = jbtacAddr.getWayBank(pcLatch)
  val jbtacRead = Wire(Vec(JbtacBanks, jbtacEntry()))
  for (i <- 0 until JbtacBanks) {
    jbtacRead(i) := DontCare
    for (j <- 0 until JbtacBanks) {
      when (j.U === jbtacRespBank) {
        jbtacRead(i) := jbtac((i + j) % JbtacBanks).io.r.resp.data(0)
      }
    }
  }

  // redirect based on BTB, PHT, RAS and JBTAC
  val redirectIdx = Wire(Vec(FetchWidth, Bool()))
  val redirectTarget = Wire(Vec(FetchWidth, UInt(VAddrBits.W)))
  for (i <- 0 until FetchWidth) {
    val entryType = btbRead(i)._type
    // A slot redirects on a BTB hit, additionally requiring a taken PHT
    // prediction for type B and a valid JBTAC entry for type I.
    redirectIdx(i) := btbHits(i) &&
      Mux(entryType === BTBtype.B, phtTaken(i), true.B) &&
      Mux(entryType === BTBtype.I, jbtacRead(i).valid, true.B)
    // Target source by type: JBTAC for I, return address for R, BTB otherwise.
    redirectTarget(i) := Mux(entryType === BTBtype.I, jbtacRead(i).target,
      Mux(entryType === BTBtype.R, retAddr, btbRead(i).target))
  }
  io.out.redirect.valid := redirectIdx.asUInt.orR
  // The lowest-numbered redirecting slot wins.
  io.out.redirect.bits := PriorityMux(redirectIdx, redirectTarget)
}