package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.backend.ALUOpType
import utils._
import chisel3.experimental.chiselName

/** One entry of the return address stack (RAS).
  *
  * A push whose address equals the current top normally does not allocate a
  * new entry; it increments `ctr` instead, so one entry stands for `ctr`
  * stacked copies of `retAddr`.
  */
class RASEntry() extends XSBundle {
  val retAddr = UInt(VAddrBits.W)
  val ctr = UInt(8.W) // layer of nested call functions
}

/** Return address stack predictor.
  *
  * Two stacks are kept:
  *  - `spec`: updated at prediction time (push on a call, pop on a return) and
  *    used to drive `io.out`;
  *  - `commit`: updated from `io.recover` for non-replay calls/returns.
  * One cycle after `io.recover` reports a misprediction or a replay, the whole
  * `commit` stack is copied into `spec`.
  */
@chiselName
class RAS extends BasePredictor
{
  class RASResp extends Resp
  {
    val target = UInt(VAddrBits.W)
  }

  class RASBranchInfo extends Meta
  {
    val rasSp = UInt(log2Up(RasSize).W)
    val rasTopCtr = UInt(8.W)
    // NOTE(review): "Toq" looks like a typo for "Top"; the name is kept because
    // code outside this file may reference it.
    val rasToqAddr = UInt(VAddrBits.W)
  }

  class RASIO extends DefaultBasePredictorIO
  {
    val is_ret = Input(Bool())
    val callIdx = Flipped(ValidIO(UInt(log2Ceil(PredictWidth).W)))
    val isRVC = Input(Bool())
    val isLastHalfRVI = Input(Bool())
    val recover = Flipped(ValidIO(new CfiUpdateInfo))
    val out = ValidIO(new RASResp)
    val meta = Output(new RASBranchInfo)
  }

  def rasEntry() = new RASEntry

  object RASEntry {
    /** Builds a RASEntry wire from its two fields. */
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(rasEntry())
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

  override val io = IO(new RASIO)
  override val debug = true

  /** A RAS made of a top-of-stack register plus a register-file backing store.
    *
    * Stack-pointer encoding (`sp` is `log2Up(rasSize) + 1` bits wide):
    *  - `sp === rasSize`: the stack is empty;
    *  - otherwise the stack holds `sp + 1` entries: the top one in
    *    `topRegister`, the others in `mem(0)` .. `mem(sp - 1)`;
    *  - `sp === rasSize - 1`: the stack is full.
    *
    * The whole state can be read out (`copy_out_*`) and overwritten in a single
    * cycle (`copy_valid` / `copy_in_*`); a copy takes priority over a push or
    * pop in the same cycle.
    */
  @chiselName
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val new_addr = Input(UInt(VAddrBits.W))
      val top_addr = Output(UInt(VAddrBits.W))
      val is_empty = Output(Bool())
      val is_full = Output(Bool())
      val copy_valid = Input(Bool())
      val copy_in_mem = Input(Vec(rasSize, rasEntry()))
      // The stack pointer is exchanged at its full width. Truncating it to
      // log2Up(rasSize) bits maps the empty encoding (rasSize) to 0 whenever
      // rasSize is a power of two, so copying an empty stack would produce a
      // non-empty one.
      val copy_in_sp = Input(UInt((log2Up(rasSize) + 1).W))
      val copy_in_top = Input(rasEntry())
      val copy_out_mem = Output(Vec(rasSize, rasEntry()))
      val copy_out_sp = Output(UInt((log2Up(rasSize) + 1).W))
      val copy_out_top = Output(rasEntry())
    })
    val debugIO = IO(new Bundle {
      val write_entry = Output(rasEntry())
      val alloc_new = Output(Bool())
      val sp = Output(UInt((log2Up(rasSize) + 1).W))
      val topRegister = Output(rasEntry())
    })

    /** Register-based storage with one read port, one write port and a
      * whole-array copy port; a copy wins over a write in the same cycle.
      */
    @chiselName
    class Stack(val size: Int) extends XSModule {
      val io = IO(new Bundle {
        val rIdx = Input(UInt(log2Up(size).W))
        val rdata = Output(rasEntry())
        val wen = Input(Bool())
        val wIdx = Input(UInt(log2Up(size).W))
        val wdata = Input(rasEntry())
        val copyen = Input(Bool())
        val copy_in = Input(Vec(size, rasEntry()))
        val copy_out = Output(Vec(size, rasEntry()))
      })
      val mem = Reg(Vec(size, rasEntry()))
      when (io.wen) {
        mem(io.wIdx) := io.wdata
      }
      io.rdata := mem(io.rIdx)
      (0 until size).foreach { i => io.copy_out(i) := mem(i) }
      when (io.copyen) {
        (0 until size).foreach { i => mem(i) := io.copy_in(i) }
      }
    }

    val sp = RegInit(rasSize.U((log2Up(rasSize) + 1).W))
    val topRegister = RegInit(0.U.asTypeOf(new RASEntry))
    val stack = Module(new Stack(rasSize)).io

    val top_addr = topRegister.retAddr
    val top_ctr = topRegister.ctr

    // Defined before first use: in a Scala class body a forward reference to a
    // later val reads null during elaboration. Uses the constructor parameter
    // (not the global RasSize) so the module is correct for any rasSize.
    val is_empty = sp === rasSize.U
    val is_full = sp === (rasSize - 1).U

    // A push allocates a fresh entry (instead of incrementing ctr) when:
    //  - the stack is empty: topRegister is stale then, and only incrementing
    //    its ctr would leave sp at the empty encoding, losing the push;
    //  - the pushed address differs from the current top;
    //  - the 8-bit ctr is saturated, so another increment would wrap to 0.
    val alloc_new = is_empty || io.new_addr =/= top_addr || top_ctr.andR

    // On a push the current top is spilled to mem(sp); on a pop the entry
    // below the top is read from mem(sp - 1).
    stack.rIdx := sp - 1.U
    stack.wen := io.push_valid && !is_empty
    stack.wIdx := sp
    val write_entry = RASEntry(top_addr, top_ctr)
    stack.wdata := write_entry

    when (io.push_valid && alloc_new) {
      sp := Mux(is_full, sp, Mux(is_empty, 0.U, sp + 1.U))
      top_addr := io.new_addr
      top_ctr := 1.U
    } .elsewhen (io.push_valid) {
      top_ctr := top_ctr + 1.U
    }

    when (io.pop_valid && top_ctr === 1.U) {
      sp := Mux(is_empty, sp, Mux(sp === 0.U, rasSize.U, sp - 1.U))
      top_addr := stack.rdata.retAddr
      top_ctr := stack.rdata.ctr
    } .elsewhen (io.pop_valid) {
      top_ctr := top_ctr - 1.U
    }

    // Whole-stack copy; placed last so it overrides any push/pop above.
    stack.copyen := io.copy_valid
    stack.copy_in := io.copy_in_mem
    when (io.copy_valid) {
      sp := io.copy_in_sp
      topRegister := io.copy_in_top
    }

    io.copy_out_mem := stack.copy_out
    io.copy_out_sp := sp
    io.copy_out_top := topRegister
    io.top_addr := top_addr
    io.is_empty := is_empty
    io.is_full := is_full

    debugIO.write_entry := write_entry
    debugIO.alloc_new := alloc_new
    debugIO.sp := sp
    debugIO.topRegister := topRegister
  }

  // ---------------- speculative stack (prediction time) ----------------
  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io

  // Address pushed for a call: packet-aligned PC plus the call's slot offset,
  // plus 2 when isRVC or isLastHalfRVI (with the C extension), else 4.
  val spec_new_addr = packetAligned(io.pc.bits) + (io.callIdx.bits << instOffsetBits) +
    Mux((io.isRVC | io.isLastHalfRVI) && HasCExtension.B, 2.U, 4.U)
  val spec_is_empty = spec_ras.is_empty
  val spec_is_full = spec_ras.is_full
  val spec_top_addr = spec_ras.top_addr

  val spec_push = !spec_is_full && io.callIdx.valid && io.pc.valid
  val spec_pop = !spec_is_empty && io.is_ret && io.pc.valid
  spec_ras.push_valid := spec_push
  spec_ras.pop_valid := spec_pop
  spec_ras.new_addr := spec_new_addr

  // ---------------- committed stack (driven by io.recover) ----------------
  val commit = Module(new RASStack(RasSize))
  val commit_ras = commit.io

  val commit_new_addr = Mux(io.recover.bits.pd.isRVC && HasCExtension.B,
    io.recover.bits.pc + 2.U, io.recover.bits.pc + 4.U)
  val commit_is_empty = commit_ras.is_empty
  val commit_is_full = commit_ras.is_full
  val commit_top_addr = commit_ras.top_addr

  val commit_update = io.recover.valid && !io.recover.bits.isReplay
  val commit_push = !commit_is_full && commit_update && io.recover.bits.pd.isCall
  val commit_pop = !commit_is_empty && commit_update && io.recover.bits.pd.isRet
  commit_ras.push_valid := commit_push
  commit_ras.pop_valid := commit_pop
  commit_ras.new_addr := commit_new_addr

  io.out.valid := !spec_is_empty
  io.out.bits.target := spec_top_addr
  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  // Restore the speculative stack from the committed one a cycle after a
  // misprediction/replay, so the copy includes any push/pop the commit stack
  // performed for that same io.recover. Reset to false so no spurious copy
  // occurs in the first cycle after reset.
  val copy_valid = io.recover.valid && (io.recover.bits.isMisPred || io.recover.bits.isReplay)
  val copy_next = RegNext(copy_valid, false.B)
  spec_ras.copy_valid := copy_next
  spec_ras.copy_in_mem := commit_ras.copy_out_mem
  spec_ras.copy_in_sp := commit_ras.copy_out_sp
  spec_ras.copy_in_top := commit_ras.copy_out_top
  commit_ras.copy_valid := false.B
  commit_ras.copy_in_mem := DontCare
  commit_ras.copy_in_sp := DontCare
  commit_ras.copy_in_top := DontCare

  // no need to pass the ras branchInfo
  io.meta.rasSp := DontCare
  io.meta.rasTopCtr := DontCare
  io.meta.rasToqAddr := DontCare

  if (BPUDebug && debug) {
    val spec_debug = spec.debugIO
    val commit_debug = commit.debugIO
    XSDebug("----------------RAS(spec)----------------\n")
    XSDebug(" index addr ctr \n")
    for (i <- 0 until RasSize) {
      XSDebug(" (%d) 0x%x %d", i.U, spec_ras.copy_out_mem(i).retAddr, spec_ras.copy_out_mem(i).ctr)
      when (i.U === spec_ras.copy_out_sp) { XSDebug(false, true.B, " <----sp") }
      XSDebug(false, true.B, "\n")
    }
    XSDebug("----------------RAS(commit)----------------\n")
    XSDebug(" index addr ctr \n")
    for (i <- 0 until RasSize) {
      XSDebug(" (%d) 0x%x %d", i.U, commit_ras.copy_out_mem(i).retAddr, commit_ras.copy_out_mem(i).ctr)
      when (i.U === commit_ras.copy_out_sp) { XSDebug(false, true.B, " <----sp") }
      XSDebug(false, true.B, "\n")
    }

    XSDebug(spec_push, "(spec_ras)push inAddr: 0x%x inCtr: %d | allocNewEntry:%d | sp:%d | TopReg.addr %x ctr:%d\n",
      spec_new_addr, spec_debug.write_entry.ctr, spec_debug.alloc_new, spec_debug.sp.asUInt,
      spec_debug.topRegister.retAddr, spec_debug.topRegister.ctr)
    XSDebug(spec_pop, "(spec_ras)pop outValid:%d outAddr: 0x%x \n", io.out.valid, io.out.bits.target)
    XSDebug(commit_push, "(commit_ras)push inAddr: 0x%x inCtr: %d | allocNewEntry:%d | sp:%d | TopReg.addr %x ctr:%d\n",
      commit_new_addr, commit_debug.write_entry.ctr, commit_debug.alloc_new, commit_debug.sp.asUInt,
      commit_debug.topRegister.retAddr, commit_debug.topRegister.ctr)
    // Report the commit stack's own top, not the speculative output.
    XSDebug(commit_pop, "(commit_ras)pop outValid:%d outAddr: 0x%x \n", !commit_is_empty, commit_top_addr)
    XSDebug("copyValid:%d copyNext:%d \n", copy_valid, copy_next)
  }
}