// xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision 6fb6170440667346cfe74250418af1f9ef0c2353)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import utils._
7
/** View of a virtual address as used to index a banked prediction table.
  *
  * Field declaration order defines the bit layout: in a Chisel Bundle the
  * first declared field occupies the most significant bits, so an address
  * decomposes (MSB to LSB) as `tag | idx | offset`.
  *
  * @param idxBits number of index bits into the table
  * @param banks   number of banks; the bank number is taken from the low
  *                `log2Up(banks)` bits of the index
  */
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  /** Address bits left over above the index and the 2-bit offset. */
  def tagBits = VAddrBits - idxBits - 2

  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  // 2-bit low offset -- presumably the byte offset inside a 4-byte slot; confirm
  val offset = UInt(2.W)

  // Width of the bank-select field at the bottom of idx.
  // NOTE(review): chisel3.util.log2Up(1) == 1, so with banks == 1 one index
  // bit is still treated as the bank number -- confirm no table uses 1 bank.
  private def bankBits = log2Up(banks)

  /** Reinterprets `x`, first resized to VAddrBits, as a TableAddr. */
  def fromUInt(x: UInt) = {
    val addr = x.asTypeOf(UInt(VAddrBits.W))
    addr.asTypeOf(this)
  }

  /** Tag field of address `x`. */
  def getTag(x: UInt) = fromUInt(x).tag

  /** Full table index of address `x`. */
  def getIdx(x: UInt) = fromUInt(x).idx

  /** Bank number: the low `bankBits` bits of the index. */
  def getBank(x: UInt) = {
    val index = getIdx(x)
    index(bankBits - 1, 0)
  }

  /** Row inside the selected bank: the index bits above the bank-select bits. */
  def getBankIdx(x: UInt) = {
    val index = getIdx(x)
    index(idxBits - 1, bankBits)
  }
}
21
/** Payload handed from BPU Stage1 to Stage2.
  *
  * Bundles the fetch PC with the FetchWidth-wide results of the `btb`,
  * `jbtac` and `tage` lookups, plus a HistoryLength-bit history value.
  * Every field is declared as an Output (producer's point of view).
  */
class Stage1To2IO extends XSBundle {
  val pc = Output(UInt(VAddrBits.W))
  val btb = new Bundle {
    // FetchWidth-bit hit vector -- presumably one bit per fetch slot
    val hits = Output(UInt(FetchWidth.W))
    // One target per fetch slot. The element width must be a Width (`.W`);
    // `VAddrBits.B` would build a Bool literal and is not a valid UInt width.
    val targets = Output(Vec(FetchWidth, UInt(VAddrBits.W)))
  }
  val jbtac = new Bundle {
    // FetchWidth bits wide despite the name -- presumably a one-hot slot mask; confirm
    val hitIdx = Output(UInt(FetchWidth.W))
    val target = Output(UInt(VAddrBits.W))
  }
  val tage = new Bundle {
    val hits = Output(UInt(FetchWidth.W))
    val takens = Output(Vec(FetchWidth, Bool()))
  }
  // HistoryLength bits of branch history
  val hist = Output(UInt(HistoryLength.W))
}
38
/** First stage of the branch prediction pipeline (currently a stub).
  *
  * Presumably the future home of the BTB/JBTAC lookup. For now it always
  * accepts the incoming PC and never produces a valid prediction or a valid
  * payload for Stage2; the `redirect` and `flush` inputs are unused.
  */
class BPUStage1 extends XSModule {
  val io = IO(new Bundle {
    val in = new Bundle {
      val pc = Flipped(Decoupled(UInt(VAddrBits.W)))
    }
    // redirect coming from the backend
    val redirect = Flipped(ValidIO(new Redirect))
    // flush request coming from Stage3
    val flush = Input(Bool())
    // quick prediction result sent to the IFU
    val btbOut = ValidIO(new BranchPrediction)
    // payload for Stage2
    val out = Decoupled(new Stage1To2IO)
  })

  // ---- Placeholder behaviour. TODO: delete once the real Stage1 logic lands ----

  // Never back-pressure the PC source.
  io.in.pc.ready := true.B

  // No quick prediction yet.
  io.btbOut.valid := false.B
  io.btbOut.bits := DontCare

  // Nothing is forwarded to Stage2 yet.
  io.out.valid := false.B
  io.out.bits := DontCare
}
60
/** Payload handed from BPU Stage2 to Stage3.
  *
  * Currently carries exactly the same fields as Stage1To2IO. It is presumably
  * kept as a distinct type so the two stage interfaces can diverge later.
  */
class Stage2To3IO extends Stage1To2IO
63
/** Second stage of the branch prediction pipeline (currently a stub).
  *
  * Refuses all input from Stage1 and never hands a valid payload to Stage3;
  * the `flush` input is unused for now.
  */
class BPUStage2 extends XSModule {
  val io = IO(new Bundle {
    // flush request coming from Stage3
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage1To2IO))
    val out = Decoupled(new Stage2To3IO)
  })

  // ---- Placeholder behaviour. TODO: delete once the real Stage2 logic lands ----
  io.in.ready := false.B   // accept nothing from Stage1
  io.out.valid := false.B  // nothing to hand to Stage3
  io.out.bits := DontCare
}
78
/** Third (final) stage of the branch prediction pipeline (currently a stub).
  *
  * Its `flushBPU` output is meant to squash Stage1 and Stage2. For now the
  * stage refuses input, emits no prediction, and never requests a flush;
  * the `flush`, `predecode` and `redirect` inputs are unused.
  */
class BPUStage3 extends XSModule {
  val io = IO(new Bundle {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new Stage2To3IO))
    val predecode = Flipped(ValidIO(new Predecode))
    val out = ValidIO(new BranchPrediction)
    // redirect from the backend (only isCall is needed here)
    val redirect = Flipped(ValidIO(new Redirect))
    // flush request sent back to Stage1 and Stage2
    val flushBPU = Output(Bool())
  })

  // ---- Placeholder behaviour. TODO: delete once the real Stage3 logic lands ----
  io.in.ready := false.B    // accept nothing from Stage2

  io.out.valid := false.B   // no final prediction yet
  io.out.bits := DontCare

  io.flushBPU := false.B    // never squash the earlier stages
}
98
/** Top-level branch prediction unit: a three-stage pipeline
  * (BPUStage1 -> BPUStage2 -> BPUStage3).
  *
  * Stage1 receives the fetch PC and drives the quick `btbOut` prediction.
  * Stage3 drives `tageOut` and can ask for Stage1/Stage2 to be squashed via
  * its `flushBPU` output. A valid backend redirect flushes all three stages.
  */
class BPU extends XSModule {
  val io = IO(new Bundle() {
    // flush pipeline and update bpu based on redirect signals from brq
    val redirect = Flipped(ValidIO(new Redirect))
    val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }
    // val predMask = Output(Vec(FetchWidth, Bool()))
    // val predTargets = Output(Vec(FetchWidth, UInt(VAddrBits.W)))
    val btbOut = ValidIO(new BranchPrediction)
    val tageOut = ValidIO(new BranchPrediction)

    // predecode info from icache
    // TODO: simplify this after implement predecode unit
    val predecode = Flipped(ValidIO(new Predecode))
  })

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // Stage1 and Stage2 are squashed by a backend redirect or by Stage3's own
  // flush request; Stage3 itself is squashed only by a redirect.
  val flushFrontStages = s3.io.flushBPU || io.redirect.valid

  // ---- Stage1: fetch PC in, quick BTB prediction out ----
  // io.in.pc has no ready, so valid/bits are wired separately and
  // s1.io.in.pc.ready is left unobserved.
  s1.io.in.pc.valid := io.in.pc.valid
  s1.io.in.pc.bits := io.in.pc.bits
  s1.io.redirect <> io.redirect
  s1.io.flush := flushFrontStages
  io.btbOut <> s1.io.btbOut

  // ---- Stage2 ----
  s1.io.out <> s2.io.in
  s2.io.flush := flushFrontStages

  // ---- Stage3: final prediction out ----
  s2.io.out <> s3.io.in
  s3.io.flush := io.redirect.valid
  s3.io.redirect <> io.redirect
  s3.io.predecode <> io.predecode
  io.tageOut <> s3.io.out

  // TODO: delete this and put BTB and JBTAC into Stage1
  //
  // NOTE(review): the disabled reference implementation below needs fixes
  // before it can be revived:
  //  - Scala gives `+`/`-` higher precedence than `<<`, so
  //    `io.in.pc.bits + FetchWidth.U << 2.U - 4.U` means
  //    `(pc + FetchWidth) << (2 - 4)`, and `... + i.U << 2` shifts the whole
  //    sum; the shifts need parentheses.
  //  - In the btbHits, jbtacHits and btbUpdateRead loops every (bank, way)
  //    iteration re-assigns the same wire in its `.otherwise` branch, so under
  //    Chisel's last-connect semantics only the final iteration takes effect
  //    and earlier matches are overwritten.
  //  - btb(b)(w).io.r.req is driven both by the prediction read and by the
  //    update read ("1.1"); the later update connection wins.
  //  - io.predMask / io.predTargets are referenced but are commented out of io.
  /*
  val flush = BoolStopWatch(io.redirect.valid, io.in.pc.valid, startHighPriority = true)

  // BTB makes a quick prediction for branch and direct jump, which is
  // 4-way set-associative, and each way is divided into 4 banks.
  val btbAddr = new TableAddr(log2Up(BtbSets), BtbBanks)
  def btbEntry() = new Bundle {
    val valid = Bool()
    // TODO: don't need full length of tag and target
    val tag = UInt(btbAddr.tagBits.W)
    val _type = UInt(2.W)
    val target = UInt(VAddrBits.W)
    val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor
  }

  val btb = List.fill(BtbBanks)(List.fill(BtbWays)(
    Module(new SRAMTemplate(btbEntry(), set = BtbSets / BtbBanks, shouldReset = true, holdRead = true, singlePort = true))))

  // val fetchPkgAligned = btbAddr.getBank(io.in.pc.bits) === 0.U
  val HeadBank = btbAddr.getBank(io.in.pc.bits)
  val TailBank = btbAddr.getBank(io.in.pc.bits + FetchWidth.U << 2.U - 4.U)
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      btb(b)(w).reset := reset.asBool
      btb(b)(w).io.r.req.valid := io.in.pc.valid && Mux(TailBank > HeadBank, b.U >= HeadBank && b.U <= TailBank, b.U >= TailBank || b.U <= HeadBank)
      btb(b)(w).io.r.req.bits.setIdx := btbAddr.getBankIdx(io.in.pc.bits)
    }
  }
  // latch pc for 1 cycle latency when reading SRAM
  val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.valid)
  val btbRead = Wire(Vec(BtbBanks, Vec(BtbWays, btbEntry())))
  val btbHits = Wire(Vec(FetchWidth, Bool()))
  val btbTargets = Wire(Vec(FetchWidth, UInt(VAddrBits.W)))
  val btbTypes = Wire(Vec(FetchWidth, UInt(2.W)))
  // val btbPreds = Wire(Vec(FetchWidth, UInt(2.W)))
  val btbTakens = Wire(Vec(FetchWidth, Bool()))
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      btbRead(b)(w) := btb(b)(w).io.r.resp.data(0)
    }
  }
  for (i <- 0 until FetchWidth) {
    btbHits(i) := false.B
    for (b <- 0 until BtbBanks) {
      for (w <- 0 until BtbWays) {
        when (b.U === btbAddr.getBank(pcLatch) && btbRead(b)(w).valid && btbRead(b)(w).tag === btbAddr.getTag(Cat(pcLatch(VAddrBits - 1, 2), 0.U(2.W)) + i.U << 2)) {
          btbHits(i) := !flush && RegNext(btb(b)(w).io.r.req.fire(), init = false.B)
          btbTargets(i) := btbRead(b)(w).target
          btbTypes(i) := btbRead(b)(w)._type
          // btbPreds(i) := btbRead(b)(w).pred
          btbTakens(i) := (btbRead(b)(w).pred)(1).asBool
        }.otherwise {
          btbHits(i) := false.B
          btbTargets(i) := DontCare
          btbTypes(i) := DontCare
          btbTakens(i) := DontCare
        }
      }
    }
  }

  // JBTAC, divided into 8 banks, makes prediction for indirect jump except ret.
  val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks)
  def jbtacEntry() = new Bundle {
    val valid = Bool()
    // TODO: don't need full length of tag and target
    val tag = UInt(jbtacAddr.tagBits.W)
    val target = UInt(VAddrBits.W)
  }

  val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = true)))

  (0 until JbtacBanks).map(i => jbtac(i).reset := reset.asBool)
  (0 until JbtacBanks).map(i => jbtac(i).io.r.req.valid := io.in.pc.valid)
  (0 until JbtacBanks).map(i => jbtac(i).io.r.req.bits.setIdx := jbtacAddr.getBankIdx(Cat((io.in.pc.bits)(VAddrBits - 1, 2), 0.U(2.W)) + i.U << 2))

  val jbtacRead = Wire(Vec(JbtacBanks, jbtacEntry()))
  (0 until JbtacBanks).map(i => jbtacRead(i) := jbtac(i).io.r.resp.data(0))
  val jbtacHits = Wire(Vec(FetchWidth, Bool()))
  val jbtacTargets = Wire(Vec(FetchWidth, UInt(VAddrBits.W)))
  val jbtacHeadBank = jbtacAddr.getBank(Cat(pcLatch(VAddrBits - 1, 2), 0.U(2.W)))
  for (i <- 0 until FetchWidth) {
    jbtacHits(i) := false.B
    for (b <- 0 until JbtacBanks) {
      when (jbtacHeadBank + i.U === b.U) {
        jbtacHits(i) := jbtacRead(b).valid && jbtacRead(b).tag === jbtacAddr.getTag(Cat(pcLatch(VAddrBits - 1, 2), 0.U(2.W)) + i.U << 2) &&
          !flush && RegNext(jbtac(b).io.r.req.fire(), init = false.B)
        jbtacTargets(i) := jbtacRead(b).target
      }.otherwise {
        jbtacHits(i) := false.B
        jbtacTargets(i) := DontCare
      }
    }
  }

  // redirect based on BTB and JBTAC
  (0 until FetchWidth).map(i => io.predMask(i) := btbHits(i) && Mux(btbTypes(i) === BTBtype.B, btbTakens(i), true.B) || jbtacHits(i))
  (0 until FetchWidth).map(i => io.predTargets(i) := Mux(btbHits(i) && !(btbTypes(i) === BTBtype.B && !btbTakens(i)), btbTargets(i), jbtacTargets(i)))


  // update bpu, including BTB, JBTAC...
  // 1. update BTB
  // 1.1 read the selected bank
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      btb(b)(w).io.r.req.valid := io.redirect.valid && btbAddr.getBank(io.redirect.bits.pc) === b.U
      btb(b)(w).io.r.req.bits.setIdx := btbAddr.getBankIdx(io.redirect.bits.pc)
    }
  }

  // 1.2 match redirect pc tag with the 4 tags in a btb line, find a way to write
  // val redirectLatch = RegEnable(io.redirect.bits, io.redirect.valid)
  val redirectLatch = RegNext(io.redirect.bits, init = 0.U.asTypeOf(new Redirect))
  val bankLatch = btbAddr.getBank(redirectLatch.pc)
  val btbUpdateRead = Wire(Vec(BtbWays, btbEntry()))
  val btbValids = Wire(Vec(BtbWays, Bool()))
  val btbUpdateTagHits = Wire(Vec(BtbWays, Bool()))
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      when (b.U === bankLatch) {
        btbUpdateRead(w) := btb(b)(w).io.r.resp.data(0)
        btbValids(w) := btbUpdateRead(w).valid && RegNext(btb(b)(w).io.r.req.fire(), init = false.B)
      }.otherwise {
        btbUpdateRead(w) := 0.U.asTypeOf(btbEntry())
        btbValids(w) := false.B
      }
    }
  }
  (0 until BtbWays).map(w => btbUpdateTagHits(w) := btbValids(w) && btbUpdateRead(w).tag === btbAddr.getTag(redirectLatch.pc))
  // val btbWriteWay = Wire(Vec(BtbWays, Bool()))
  val btbWriteWay = Wire(UInt(BtbWays.W))
  val btbInvalids = ~ btbValids.asUInt
  when (btbUpdateTagHits.asUInt.orR) {
    // tag hits
    btbWriteWay := btbUpdateTagHits.asUInt
  }.elsewhen (!btbValids.asUInt.andR) {
    // no tag hits but there are free entries
    btbWriteWay := Mux(btbInvalids >= 8.U, "b1000".U,
      Mux(btbInvalids >= 4.U, "b0100".U,
      Mux(btbInvalids >= 2.U, "b0010".U, "b0001".U)))
  }.otherwise {
    // no tag hits and no free entry, select a victim way
    btbWriteWay := UIntToOH(LFSR64()(log2Up(BtbWays) - 1, 0))
  }

  // 1.3 calculate new 2-bit counter value
  val btbWrite = WireInit(0.U.asTypeOf(btbEntry()))
  btbWrite.valid := true.B
  btbWrite.tag := btbAddr.getTag(redirectLatch.pc)
  btbWrite._type := redirectLatch._type
  btbWrite.target := redirectLatch.brTarget
  val oldPred = WireInit("b01".U)
  oldPred := PriorityMux(btbWriteWay.asTypeOf(Vec(BtbWays, Bool())), btbUpdateRead.map{ e => e.pred })
  val newPred = Mux(redirectLatch.taken, Mux(oldPred === "b11".U, "b11".U, oldPred + 1.U),
    Mux(oldPred === "b00".U, "b00".U, oldPred - 1.U))
  btbWrite.pred := Mux(btbUpdateTagHits.asUInt.orR && redirectLatch._type === BTBtype.B, newPred, "b01".U)

  // 1.4 write BTB
  for (b <- 0 until BtbBanks) {
    for (w <- 0 until BtbWays) {
      when (b.U === bankLatch) {
        btb(b)(w).io.w.req.valid := OHToUInt(btbWriteWay) === w.U &&
          RegNext(io.redirect.valid, init = false.B) &&
          (redirectLatch._type === BTBtype.B || redirectLatch._type === BTBtype.J)
        btb(b)(w).io.w.req.bits.setIdx := btbAddr.getBankIdx(redirectLatch.pc)
        btb(b)(w).io.w.req.bits.data := btbWrite
      }.otherwise {
        btb(b)(w).io.w.req.valid := false.B
        btb(b)(w).io.w.req.bits.setIdx := DontCare
        btb(b)(w).io.w.req.bits.data := DontCare
      }
    }
  }

  // 2. update JBTAC
  val jbtacWrite = WireInit(0.U.asTypeOf(jbtacEntry()))
  jbtacWrite.valid := true.B
  jbtacWrite.tag := jbtacAddr.getTag(io.redirect.bits.pc)
  jbtacWrite.target := io.redirect.bits.target
  (0 until JbtacBanks).map(b =>
    jbtac(b).io.w.req.valid := io.redirect.valid &&
      b.U === jbtacAddr.getBank(io.redirect.bits.pc) &&
      io.redirect.bits._type === BTBtype.I)
  (0 until JbtacBanks).map(b => jbtac(b).io.w.req.bits.setIdx := jbtacAddr.getBankIdx(io.redirect.bits.pc))
  (0 until JbtacBanks).map(b => jbtac(b).io.w.req.bits.data := jbtacWrite)
  */
}
321