xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision ee286e3b31ae1b5b9b93a480461aadca454c6e73)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import utils._
6import xiangshan._
7import xiangshan.backend.ALUOpType
8import xiangshan.backend.JumpOpType
9
// Splits a virtual address into the tag / index / offset fields used to
// address a banked predictor table.
//   idxBits: width of the table index (bank-select bits included)
//   banks:   number of banks; the low log2Ceil(banks) index bits select the bank
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  // Everything above the index and the 1-bit offset is tag.
  def tagBits = VAddrBits - idxBits - 1

  // Declaration order is the bit layout: tag in the MSBs, offset at bit 0.
  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  // log2Ceil (not log2Up): log2Up(1) == 1 would steal an index bit for a
  // bank select that does not exist when there is a single bank.
  private def bankBits: Int = log2Ceil(banks)

  // Reinterpret x (truncated / zero-extended to VAddrBits) as a TableAddr.
  def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
  def getTag(x: UInt) = fromUInt(x).tag
  def getIdx(x: UInt) = fromUInt(x).idx
  // Bank number; always 0 for a single-bank table.
  def getBank(x: UInt) = if (bankBits > 0) getIdx(x)(bankBits - 1, 0) else 0.U
  // Row inside the selected bank (the index bits above the bank select).
  def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, bankBits)
}
23
// Prediction request presented to the BPU (see BaseBPU.io.in).
// Field declaration order defines the bit layout of the bundle.
class BPUReq extends XSBundle {
  // fetch pc of the request
  val pc = UInt(VAddrBits.W)
  // global branch history, HistoryLength bits
  val hist = UInt(HistoryLength.W)
  // one bit per slot of the fetch packet (PredictWidth slots)
  val inMask = UInt(PredictWidth.W)
}
29
// Raw responses of the individual predictors, one entry per fetch slot.
class PredictorResponse extends XSBundle {
  // micro-BTB: the valid bit of each target indicates whether that slot hit
  val ubtb = new Bundle {
    val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VAddrBits.W)))
    val takens = Vec(PredictWidth, Bool())
    val isRVC = Vec(PredictWidth, Bool())
  }
  // BTB: the valid bit of each target indicates whether that slot hit
  val btb = new Bundle {
    val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VAddrBits.W)))
    val isRVC = Vec(PredictWidth, Bool())
  }
  // BIM: one valid-tagged Bool per slot
  val bim = new Bundle {
    val ctrs = Vec(PredictWidth, ValidUndirectioned(Bool()))
  }
  // TAGE: the valid bit indicates whether a prediction hit
  val tage = new Bundle {
    val takens = Vec(PredictWidth, ValidUndirectioned(Bool()))
  }
}
50
// Payload carried from one BPU pipeline stage to the next.
// Field declaration order defines the bit layout of the bundle.
class BPUStageIO extends XSBundle {
  // fetch pc of the request
  val pc = UInt(VAddrBits.W)
  // one bit per fetch slot
  val mask = UInt(PredictWidth.W)
  // predictor responses gathered so far
  val resp = new PredictorResponse
  // per-slot branch bookkeeping (contents defined by BranchInfo elsewhere)
  val brInfo = Vec(PredictWidth, new BranchInfo)
}
57
58
// Common skeleton of a BPU pipeline stage.
// A request accepted on `in` is latched into inLatch. In the following cycle
// (inFireLatch set) the concrete stage computes fresh `pred` / `out` values.
// In later cycles the snapshots predLatch / outLatch are replayed until the
// output is consumed.
abstract class BPUStage extends XSModule {
  class defaultIO extends XSBundle {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new BPUStageIO))
    val pred = Decoupled(new BranchPrediction)
    val out = Decoupled(new BPUStageIO)
  }
  // Declared and bound here so the shared logic below can use it. An
  // abstract `io` implemented by a subclass would still be null while this
  // constructor runs.
  val io = IO(new defaultIO)

  // Always accept. A newer request overwrites the one latched in this stage.
  io.in.ready := true.B
  val inFire = io.in.fire()
  val inFireLatch = RegNext(inFire, init = false.B)
  val inLatch = RegEnable(io.in.bits, inFire)
  // Snapshots of the values produced in the inFireLatch cycle.
  val predLatch = RegEnable(io.pred.bits, inFireLatch)
  val outLatch = RegEnable(io.out.bits, inFireLatch)

  io.out.bits <> DontCare

  // Output valid: a flush clears it (highest priority), an accepted request
  // sets it, and it is cleared once the next stage takes the output.
  val outValid = RegInit(false.B)
  val outFire = io.out.fire()
  when (io.flush) {
    outValid := false.B
  }.elsewhen (inFire) {
    outValid := true.B
  }.elsewhen (outFire) {
    outValid := false.B
  }
  io.out.valid := outValid

  // The prediction is announced in the cycle the output is handed on.
  io.pred.valid := io.out.fire()
}

// Stage 1: predicts from the micro-BTB response.
// It forwards the request and the BTB/BIM responses to the next stage. The
// predictors themselves are instantiated in BPU (BranchPredictorComponents),
// which feeds their responses in through io.in.bits.resp, so this stage does
// not instantiate its own copies.
class BPUStage1 extends BPUStage {

  // BPU drives the micro-BTB with the request pc latched after acceptance.
  // In the inFireLatch cycle, io.in.bits.resp.ubtb therefore carries the
  // response for the request now in inLatch.
  // NOTE(review): assumes the micro-BTB lookup is combinational — TODO confirm.
  val ubtbResp = io.in.bits.resp.ubtb
  // Bit i is set when slot i hits in the micro-BTB and is predicted taken.
  // The read is already masked, so no extra masking is needed here.
  val ubtbTakens = VecInit((0 until PredictWidth).map(i =>
    ubtbResp.targets(i).valid && ubtbResp.takens(i))).asUInt
  val taken = ubtbTakens.orR
  // lowest-numbered taken slot
  val jmpIdx = PriorityEncoder(ubtbTakens)
  // Highest slot whose mask bit is set. The explicit `by -1` is required:
  // `PredictWidth-1 to 0` is an empty range in Scala.
  val lastValidPos = PriorityMux((PredictWidth - 1 to 0 by -1).map(i => (inLatch.mask(i), i.U)))

  // Fresh prediction in the inFireLatch cycle, replayed from predLatch otherwise.
  io.pred.bits := predLatch
  when (inFireLatch) {
    io.pred.bits.redirect := taken
    io.pred.bits.jmpIdx := jmpIdx
    io.pred.bits.target := ubtbResp.targets(jmpIdx).bits
    // Set when the last valid slot is a non-RVC instruction and it is either
    // the taken slot or nothing is taken.
    io.pred.bits.saveHalfRVI := ((lastValidPos === jmpIdx && taken) || !taken) && !ubtbResp.isRVC(lastValidPos)
  }

  // Fresh payload in the inFireLatch cycle, replayed from outLatch otherwise.
  io.out.bits := outLatch
  when (inFireLatch) {
    io.out.bits.pc := inLatch.pc
    io.out.bits.mask := inLatch.mask
    io.out.bits.resp.ubtb := ubtbResp
    io.out.bits.resp.btb := io.in.bits.resp.btb
    io.out.bits.resp.bim := io.in.bits.resp.bim
    io.out.bits.resp.tage := DontCare
    // NOTE(review): branch info is still not forwarded, as in the original — TODO.
    io.out.bits.brInfo := DontCare
  }
}
135
// Stage 2 placeholder: same interface as the other stages, no prediction logic yet.
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new BPUStageIO))
    val pred = Decoupled(new BranchPrediction)
    val out = Decoupled(new BPUStageIO)
  })

  // Until real logic exists, drain every request and never assert a
  // prediction or an output, so every output is driven and the module
  // elaborates.
  io.in.ready := true.B
  io.pred.valid := false.B
  io.pred.bits := DontCare
  io.out.valid := false.B
  io.out.bits := DontCare
}
144
// Stage 3 placeholder: the last stage.
// It has no `out`, but takes predecode information from the fetch unit.
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new BPUStageIO))
    val pred = Decoupled(new BranchPrediction)
    val predecode = Flipped(ValidIO(new Predecode))
  })

  // Until real logic exists, drain every request and never assert a
  // prediction, so every output is driven and the module elaborates.
  io.in.ready := true.B
  io.pred.valid := false.B
  io.pred.bits := DontCare
}
153
154
// Common interface declarations for individual branch predictors.
abstract class BasePredictor {
  // NOTE(review): presumably the width of this predictor's meta; 0 by default — TODO confirm.
  val metaLen = 0

  // An implementation MUST extend the IO bundle with a response,
  // the special input from other predictors, and the metas to store in BRQ.
  abstract class resp extends XSBundle {}
  abstract class fromOthers extends XSBundle {}
  abstract class meta extends XSBundle {}

  // Inputs shared by every predictor.
  // Field declaration order defines the bit layout of the bundle.
  class defaultBasePredictorIO extends XSBundle {
    val flush = Input(Bool())
    // request pc, qualified by a valid bit
    val pc = Flipped(ValidIO(UInt(VAddrBits.W)))
    // extended global history plus a log2Up(ExtHistoryLength)-bit pointer
    val hist = Input(new Bundle {
      val bits = UInt(ExtHistoryLength.W)
      val ptr = UInt(log2Up(ExtHistoryLength).W)
    })
    // one bit per fetch slot
    val inMask = Input(UInt(PredictWidth.W))
    // training information
    val update = Flipped(ValidIO(new BranchUpdateInfo))
  }
}
176
// Instantiates the branch predictors shared by a BPU implementation.
// Mix into a Module only, because Module(...) must be called during elaboration.
trait BranchPredictorComponents extends HasXSParameter {
  val ubtb = Module(new MicroBTB)
  val btb = Module(new BTB)
  val bim = Module(new BIM)
  val tage = Module(new Tage)
  val preds = Seq(ubtb, btb, bim, tage)

  // Default every predictor port to DontCare; the enclosing BPU overrides
  // the ports it uses (later connections win). Done per module so each `io`
  // is reached through the module's own type.
  ubtb.io := DontCare
  btb.io := DontCare
  bim.io := DontCare
  tage.io := DontCare
}
185
// IO shared by all BPU implementations; the predictors come from BranchPredictorComponents.
abstract class BaseBPU extends XSModule with BranchPredictorComponents{
  val io = IO(new Bundle() {
    // from backend
    val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
    // from ifu, frontend redirect; 3 bits (BPU uses one bit per stage)
    val flush = Input(UInt(3.W))
    // from if1
    val in = Flipped(ValidIO(new BPUReq))
    // to if2/if3/if4
    val out = Vec(3, Decoupled(new BranchPrediction))
    // from if4
    val predecode = Flipped(ValidIO(new Predecode))
    // to if4, some bpu info used for updating
    val branchInfo = Decoupled(Vec(PredictWidth, new BranchInfo))
  })
}
202
203
// BPU stand-in that never redirects.
// Every output is DontCare except `redirect`, which is tied to false.
class FakeBPU extends BaseBPU {
  io.out.foreach { out =>
    out <> DontCare
    // `redirect` is a field of the prediction payload, not of the Decoupled port.
    out.bits.redirect := false.B
  }
  io.branchInfo <> DontCare
}
211
// Three-stage branch prediction unit.
// It wires the shared predictors to stage 1 and chains stage 1 -> 2 -> 3.
class BPU extends BaseBPU {

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // one flush bit per stage: bit 0 -> s1, bit 1 -> s2, bit 2 -> s3
  s1.io.flush := io.flush(0)
  s2.io.flush := io.flush(1)
  s3.io.flush := io.flush(2)

  //**********************Stage 1****************************//
  val s1_fire = io.in.valid
  val s1_resp_in = Wire(new PredictorResponse)
  val s1_brInfo_in = Wire(Vec(PredictWidth, new BranchInfo))

  s1_resp_in := DontCare
  s1_brInfo_in := DontCare

  // Request pc as a Valid bundle.
  // NOTE(review): assumes the predictors' `in.pc` port is a Valid(UInt) like
  // BasePredictor's default IO — TODO confirm.
  val s0_pc = Wire(Valid(UInt(VAddrBits.W)))
  s0_pc.valid := s1_fire
  s0_pc.bits := io.in.bits.pc

  // The micro-BTB is looked up with the request latched one cycle after acceptance.
  val s1_inLatch = RegEnable(io.in.bits, s1_fire)
  val s1_pc = Wire(Valid(UInt(VAddrBits.W)))
  s1_pc.valid := RegNext(s1_fire, init = false.B)
  s1_pc.bits := s1_inLatch.pc

  ubtb.io.in.pc <> s1_pc
  ubtb.io.in.mask := s1_inLatch.inMask

  btb.io.in.pc <> s0_pc
  btb.io.in.mask := io.in.bits.inMask
  btb.io.update := io.inOrderBrInfo

  bim.io.in.pc <> s0_pc
  bim.io.in.mask := io.in.bits.inMask
  bim.io.update := io.inOrderBrInfo

  // NOTE(review): hist is a plain UInt in BPUReq; confirm it matches Tage's input type.
  tage.io.in.pc <> s0_pc
  tage.io.in.hist := io.in.bits.hist

  // Wrap predictor responses into resp_in
  s1_resp_in.ubtb <> ubtb.io.out
  s1_resp_in.btb <> btb.io.out
  s1_resp_in.bim <> bim.io.out

  // Wrap per-slot bookkeeping into brInfo_in.
  // NOTE(review): assumes BranchInfo holds per-slot values and that the
  // ubtb/btb meta writeWay signals are per-slot Vecs — TODO confirm.
  for (i <- 0 until PredictWidth) {
    s1_brInfo_in(i).ubtbWriteWay := ubtb.io.meta.writeWay(i)
    s1_brInfo_in(i).ubtbHits := ubtb.io.out.targets(i).valid
    s1_brInfo_in(i).btbWriteWay := btb.io.meta.writeWay(i)
    s1_brInfo_in(i).bimCtrs := bim.io.out.ctrs(i)
  }

  // io.in has no ready, so stage 1's backpressure cannot reach the fetch unit.
  s1.io.in.valid := io.in.valid
  s1.io.in.bits.pc := io.in.bits.pc
  s1.io.in.bits.mask := io.in.bits.inMask
  s1.io.in.bits.resp := s1_resp_in
  s1.io.in.bits.brInfo := s1_brInfo_in

  s2.io.in <> s1.io.out
  s3.io.in <> s2.io.out

  io.out(0) <> s1.io.pred
  io.out(1) <> s2.io.pred
  io.out(2) <> s3.io.pred

  s3.io.predecode <> io.predecode

  // NOTE(review): branch info for if4 is not produced yet — TODO.
  io.branchInfo <> DontCare
}
274