xref: /XiangShan/src/main/scala/xiangshan/frontend/BPU.scala (revision 53bf607751fc7cb144b954330aa9045e4f3987e0)
1package xiangshan.frontend
2
3import chisel3._
4import chisel3.util._
5import utils._
6import xiangshan._
7import xiangshan.backend.ALUOpType
8import xiangshan.backend.JumpOpType
9
/**
 * View of a virtual address split into the fields used to index a predictor table.
 *
 * Fields are declared MSB-first as tag | idx | offset, with a 1-bit offset.
 * The low log2Up(banks) bits of idx select the bank, and the remaining idx bits
 * index within that bank.
 *
 * @param idxBits width of the table index field
 * @param banks   number of banks the index is interleaved across
 */
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
  def tagBits: Int = VAddrBits - idxBits - 1

  // Declaration order fixes the bit layout that fromUInt reinterprets.
  val tag = UInt(tagBits.W)
  val idx = UInt(idxBits.W)
  val offset = UInt(1.W)

  /** Reinterpret `x` (resized to VAddrBits) as a TableAddr. */
  def fromUInt(x: UInt): TableAddr = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)

  def getTag(x: UInt): UInt = fromUInt(x).tag
  def getIdx(x: UInt): UInt = fromUInt(x).idx

  // Number of low idx bits that select the bank.
  private def bankBits: Int = log2Up(banks)

  def getBank(x: UInt): UInt = getIdx(x)(bankBits - 1, 0)
  def getBankIdx(x: UInt): UInt = getIdx(x)(idxBits - 1, bankBits)
}
23
/**
 * Raw per-slot responses of every predictor, carried down the BPU pipeline.
 * Each Vec has one entry per instruction slot (PredictWidth).
 */
class PredictorResponse extends XSBundle {
  // micro-BTB: the valid bit of a target indicates whether that slot hit
  val ubtb = new Bundle {
    val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VAddrBits.W)))
    val takens = Vec(PredictWidth, Bool())
    val isRVC = Vec(PredictWidth, Bool())
  }
  // BTB: the valid bit of a target indicates whether that slot hit
  val btb = new Bundle {
    val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VAddrBits.W)))
    val isRVC = Vec(PredictWidth, Bool())
  }
  // bimodal predictor: one 1-bit entry per slot
  // NOTE(review): confirm a single Bool per slot matches BIM's counter width.
  val bim = new Bundle {
    val ctrs = Vec(PredictWidth, ValidUndirectioned(Bool()))
  }
  // TAGE: the valid bit indicates whether the prediction for that slot hit
  val tage = new Bundle {
    val takens = Vec(PredictWidth, ValidUndirectioned(Bool()))
  }
}
44
/** Payload handed from one BPU pipeline stage to the next. */
class BPUStageIO extends XSBundle {
  // Declaration order defines the bundle's bit layout; do not reorder.
  val pc: UInt = UInt(VAddrBits.W)                            // pc of the fetch packet
  val mask: UInt = UInt(PredictWidth.W)                       // one bit per instruction slot
  val resp: PredictorResponse = new PredictorResponse         // raw predictor outputs
  val brInfo: Vec[BranchInfo] = Vec(PredictWidth, new BranchInfo) // per-slot branch bookkeeping
}
51
52
/**
 * Skeleton shared by the BPU pipeline stages.
 *
 * A packet accepted on io.in is latched into inLatch. The concrete stage
 * computes its outputs in the following cycle (inFireLatch), and the outputs
 * are then held in predLatch / outLatch until the next packet arrives.
 * io.out is valid from that cycle until it is consumed or flushed.
 */
abstract class BPUStage extends XSModule {
  class defaultIO extends XSBundle {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new BPUStageIO))
    val pred = Decoupled(new BranchPrediction)
    val out = Decoupled(new BPUStageIO)
  }
  // Declared here rather than in the subclasses because the logic below
  // already uses io while the base class is being constructed.
  val io = IO(new defaultIO)

  // Stages never back-pressure: a new packet is always accepted.
  io.in.ready := true.B
  val inFire = io.in.fire()
  val inFireLatch = RegNext(inFire, init = false.B)
  val inLatch = RegEnable(io.in.bits, inFire)
  // Capture what the subclass produced in the inFireLatch cycle so it can be held afterwards.
  val predLatch = RegEnable(io.pred.bits, inFireLatch)
  val outLatch = RegEnable(io.out.bits, inFireLatch)

  io.out.bits <> DontCare

  // Priority: a flush invalidates the output; a new packet makes it valid
  // (even when the old one is consumed in the same cycle); otherwise a
  // consumer handshake clears it.
  val outValid = RegInit(false.B)
  val outFire = io.out.fire()
  when (io.flush) {
    outValid := false.B
  }.elsewhen (inFire) {
    outValid := true.B
  }.elsewhen (outFire) {
    outValid := false.B
  }
  io.out.valid := outValid

  io.pred.valid := io.out.fire()
}

/**
 * Stage 1: turns the micro-BTB response into the first prediction.
 *
 * The predictors themselves are instantiated and driven by the top-level BPU,
 * which feeds their responses in through io.in.bits.resp. The BPU reads the
 * uBTB with the pc it latched on the previous fire, so the response that
 * belongs to the packet in inLatch is on io.in.bits during the inFireLatch
 * cycle. That is why io.in (not inLatch) is read for responses below.
 */
class BPUStage1 extends BPUStage {

  val ubtbResp = io.in.bits.resp.ubtb
  // the read operation is already masked, so we do not need to mask here
  val ubtbTakens = Reverse(Cat((0 until PredictWidth).map(i => ubtbResp.targets(i).valid && ubtbResp.takens(i))))
  val taken = ubtbTakens.orR
  val jmpIdx = PriorityEncoder(ubtbTakens)
  // Highest slot whose mask bit is set. The range must count down explicitly:
  // `PredictWidth-1 to 0` is an empty range in Scala.
  val lastValidPos = PriorityMux((PredictWidth - 1 to 0 by -1).map(i => (inLatch.mask(i), i.U)))

  // Default: keep presenting the previous prediction; refresh it the cycle
  // after a new packet was accepted.
  io.pred.bits := predLatch
  when (inFireLatch) {
    io.pred.bits.redirect := taken
    io.pred.bits.jmpIdx := jmpIdx
    io.pred.bits.target := ubtbResp.targets(jmpIdx).bits
    // Set when execution reaches the last valid slot (no taken jump before it)
    // and that slot is not marked RVC.
    // NOTE(review): presumably flags a 32-bit instruction split across packets.
    io.pred.bits.saveHalfRVI := ((lastValidPos === jmpIdx && taken) || !taken) && !ubtbResp.isRVC(lastValidPos)
  }

  // Default: hold the previous output; refresh it alongside the prediction.
  io.out.bits := outLatch
  when (inFireLatch) {
    io.out.bits.pc := inLatch.pc
    io.out.bits.mask := inLatch.mask
    io.out.bits.resp.ubtb := ubtbResp
    io.out.bits.resp.btb := io.in.bits.resp.btb
    io.out.bits.resp.bim := io.in.bits.resp.bim
    io.out.bits.resp.tage <> DontCare
    io.out.bits.brInfo.foreach(_ <> DontCare)
  }
}
129
/** Stage 2 of the BPU pipeline (placeholder until its predictor logic is added). */
class BPUStage2 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new BPUStageIO))
    val pred = Decoupled(new BranchPrediction)
    val out = Decoupled(new BPUStageIO)
  })

  // Placeholder: forward the stage-1 packet unchanged and never issue a
  // prediction of its own. Without these drives the outputs are unconnected
  // and FIRRTL rejects the module as not fully initialized.
  // TODO: replace with the real stage-2 logic; io.flush is not used yet.
  io.out <> io.in
  io.pred.valid := false.B
  io.pred.bits := DontCare
}
138
/** Stage 3 of the BPU pipeline (placeholder until its predictor logic is added). */
class BPUStage3 extends XSModule {
  val io = IO(new Bundle() {
    val flush = Input(Bool())
    val in = Flipped(Decoupled(new BPUStageIO))
    val pred = Decoupled(new BranchPrediction)
    val predecode = Flipped(ValidIO(new Predecode))
  })

  // Placeholder: accept every packet and never issue a prediction. Without
  // these drives the outputs are unconnected and FIRRTL rejects the module
  // as not fully initialized.
  // TODO: replace with the real stage-3 logic; io.flush and io.predecode are not used yet.
  io.in.ready := true.B
  io.pred.valid := false.B
  io.pred.bits := DontCare
}
147
148
/**
 * Scaffolding shared by the individual branch predictors.
 *
 * A concrete predictor MUST extend the IO bundle with its response, the
 * special inputs it takes from other predictors, and the metas to store in BRQ.
 */
abstract class BasePredictor {
  // Length of this predictor's meta; 0 unless an implementation overrides it.
  val metaLen: Int = 0

  // Bundles each predictor specialises: its response, its inputs from other
  // predictors, and its meta.
  abstract class resp extends XSBundle
  abstract class fromOthers extends XSBundle
  abstract class meta extends XSBundle

  // Ports common to every predictor (declaration order defines the bit layout).
  class defaultBasePredictorIO extends XSBundle {
    val flush = Input(Bool())
    val pc = Flipped(ValidIO(UInt(VAddrBits.W)))
    val hist = Input(new Bundle {
      val bits = UInt(ExtHistoryLength.W)
      val ptr = UInt(log2Up(ExtHistoryLength).W)
    })
    val inMask = Input(UInt(PredictWidth.W))
    val update = Flipped(ValidIO(new BranchUpdateInfo))
  }
}
170
/**
 * Instantiates the predictors shared by every BPU implementation and ties all
 * of their ports to DontCare. A concrete BPU then only drives the ports it
 * actually uses; its later connections override these defaults (last connect wins).
 */
trait BranchPredictorComponents extends HasXSParameter {
  val ubtb = Module(new MicroBTB)
  val btb = Module(new BTB)
  val bim = Module(new BIM)
  val tage = Module(new Tage)
  val preds = Seq(ubtb, btb, bim, tage)

  // Tied off one by one: the predictors are different classes, so the
  // inferred element type of `preds` need not expose an `io` member.
  ubtb.io <> DontCare
  btb.io <> DontCare
  bim.io <> DontCare
  tage.io <> DontCare
}
179
/**
 * Interface common to every BPU implementation. The predictor modules
 * themselves come from the BranchPredictorComponents mixin.
 */
abstract class BaseBPU extends XSModule with BranchPredictorComponents {
  // Field order is kept as-is: it defines the port layout.
  val io = IO(new Bundle {
    val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) // from backend
    val flush = Input(UInt(3.W))                               // from ifu: frontend redirect
    val in = new Bundle {                                      // from if1
      val pc = Flipped(ValidIO(UInt(VAddrBits.W)))
      val hist = Input(UInt(ExtHistoryLength.W))
      val inMask = Input(UInt(PredictWidth.W))
    }
    val out = Vec(3, Decoupled(new BranchPrediction))          // to if2/if3/if4
    val predecode = Flipped(ValidIO(new Predecode))            // from if4
    val branchInfo = Decoupled(Vec(PredictWidth, new BranchInfo)) // to if4: bpu info used for updating
  })
}
200
201
/** BPU stand-in that never redirects: every stage reports "no redirect". */
class FakeBPU extends BaseBPU {
  io.out.foreach { out =>
    out <> DontCare
    // redirect lives in the payload of the DecoupledIO, not on the interface itself
    out.bits.redirect := false.B
  }
  io.branchInfo <> DontCare
}
209
/**
 * Three-stage branch prediction unit.
 *
 * This module drives the shared predictors with the incoming pc and collects
 * their responses into the payload of stage 1. Stages s1/s2/s3 produce the
 * predictions on io.out(0), io.out(1) and io.out(2).
 */
class BPU extends BaseBPU {

  val s1 = Module(new BPUStage1)
  val s2 = Module(new BPUStage2)
  val s3 = Module(new BPUStage3)

  // io.flush carries one flush bit per stage
  s1.io.flush := io.flush(0)
  s2.io.flush := io.flush(1)
  s3.io.flush := io.flush(2)

  //**********************Stage 1****************************//
  val s1_fire = io.in.pc.valid
  // Predictor responses and per-slot bookkeeping collected for stage 1;
  // anything not driven below is DontCare.
  val s1_resp_in = Wire(new PredictorResponse)
  val s1_brInfo_in = Wire(Vec(PredictWidth, new BranchInfo))

  s1_resp_in := DontCare
  s1_brInfo_in := DontCare

  // The uBTB is read with the pc latched on the previous fire.
  val s1_inLatch = RegEnable(io.in, s1_fire)
  ubtb.io.in.pc <> s1_inLatch.pc
  ubtb.io.in.mask := s1_inLatch.inMask

  // Wrap ubtb response into resp_in and brInfo_in
  s1_resp_in.ubtb <> ubtb.io.out
  // NOTE(review): s1_brInfo_in is a Vec(PredictWidth, BranchInfo), so the
  // whole-Vec field accesses below (here and for btb/bim) do not type-check.
  // They need per-slot indexing once the shapes of BranchInfo and the
  // predictor metas are settled.
  s1_brInfo_in.ubtbWriteWay := ubtb.io.meta.writeWay
  s1_brInfo_in.ubtbHits := VecInit(ubtb.io.out.targets.map(_.valid))

  // BTB and BIM are read with the incoming (s0) pc.
  btb.io.in.pc <> io.in.pc
  btb.io.in.mask := io.in.inMask

  // Wrap btb response into resp_in and brInfo_in
  s1_resp_in.btb <> btb.io.out
  s1_brInfo_in.btbWriteWay := btb.io.meta.writeWay

  bim.io.in.pc <> io.in.pc
  bim.io.in.mask := io.in.inMask

  // Wrap bim response into resp_in and brInfo_in
  s1_resp_in.bim <> bim.io.out
  s1_brInfo_in.bimCtrs := bim.io.out

  tage.io.in.pc <> io.in.pc
  tage.io.in.hist := io.in.hist

  // Backend updates go straight to the predictors; the stages have no update port.
  // TODO: ubtb and tage updates are not wired yet.
  btb.io.update := io.inOrderBrInfo
  bim.io.update := io.inOrderBrInfo

  s1.io.in.valid := s1_fire
  s1.io.in.bits.pc := io.in.pc.bits
  s1.io.in.bits.mask := io.in.inMask
  s1.io.in.bits.resp := s1_resp_in
  s1.io.in.bits.brInfo := s1_brInfo_in

  s2.io.in <> s1.io.out
  s3.io.in <> s2.io.out

  io.out(0) <> s1.io.pred
  io.out(1) <> s2.io.pred
  io.out(2) <> s3.io.pred

  s3.io.predecode <> io.predecode

  // TODO: drive from the last stage once brInfo flows through the pipeline;
  // tied off for now so the output is not left uninitialized.
  io.branchInfo <> DontCare
}
272