package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.backend.ALUOpType
import utils._

/** Return Address Stack (RAS) predictor.
  *
  * Two [[RASStack]] instances are kept:
  *  - `spec_ras` is updated at prediction time (a valid `io.callIdx` pushes,
  *    `io.is_ret` pops) and supplies the predicted return target on `io.out`;
  *  - `commit_ras` is updated from `io.recover` (a call pushes, a return pops).
  * When `io.recover` reports a misprediction, the contents of `commit_ras` are
  * copied into `spec_ras` one cycle later.
  *
  * Every entry carries a counter ("layer of nested call functions") so that
  * repeated pushes of the same return address share a single stack slot.
  */
class RAS extends BasePredictor
{
  /** Prediction result: the return target read from the top of the speculative stack. */
  class RASResp extends Resp
  {
    val target = UInt(VAddrBits.W)
  }

  /** Per-branch metadata; currently not driven (all fields are DontCare below).
    * NOTE: `rasToqAddr` (sic) keeps its original spelling for interface compatibility.
    */
  class RASBranchInfo extends Meta
  {
    val rasSp = UInt(log2Up(RasSize).W)
    val rasTopCtr = UInt(8.W)
    val rasToqAddr = UInt(VAddrBits.W)
  }

  class RASIO extends DefaultBasePredictorIO
  {
    val is_ret = Input(Bool())
    // Index of the call instruction within the fetch packet.
    val callIdx = Flipped(ValidIO(UInt(log2Ceil(PredictWidth).W)))
    val isRVC = Input(Bool())
    val recover = Flipped(ValidIO(new BranchUpdateInfo))
    val out = ValidIO(new RASResp)
    val branchInfo = Output(new RASBranchInfo)
  }

  class RASEntry() extends XSBundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // layer of nested call functions
  }

  def rasEntry() = new RASEntry

  object RASEntry {
    /** Builds a RASEntry wire from a return address and a nesting counter. */
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(rasEntry())
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

  override val io = IO(new RASIO)

  /** A stack of [[RASEntry]] with `rasSize` slots.
    *
    * `sp` points at the next free slot, so the top entry lives at `sp - 1`.
    *  - push: a new entry `(new_addr, 1)` is written at `sp` and `sp` advances
    *    when the stack is empty, the top counter is saturated, or `new_addr`
    *    differs from the top address; otherwise the top counter is incremented.
    *  - pop: if the top counter is 1 the slot is released (`sp` retreats),
    *    otherwise the counter is decremented.
    *  - copy: replaces the whole memory and `sp`; it is connected last, so it
    *    wins over a push/pop in the same cycle.
    * `is_full` is raised at `sp === rasSize - 1`, one below capacity.
    * NOTE(review): the stack itself does not guard push-when-full or
    * pop-when-empty; the callers in this file gate push/pop with `is_full`/`is_empty`.
    */
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val new_addr = Input(UInt(VAddrBits.W))
      val top_addr = Output(UInt(VAddrBits.W))
      val is_empty = Output(Bool())
      val is_full = Output(Bool())
      val copy_valid = Input(Bool())
      val copy_in_mem = Input(Vec(rasSize, rasEntry()))
      val copy_in_sp = Input(UInt(log2Up(rasSize).W))
      val copy_out_mem = Output(Vec(rasSize, rasEntry()))
      val copy_out_sp = Output(UInt(log2Up(rasSize).W))
    })

    /** Register array backing the stack: one read port, one write port and a
      * whole-array copy port. The copy `when` is connected after the write, so
      * copy takes priority when both are enabled in the same cycle.
      */
    class Stack(val size: Int) extends XSModule {
      val io = IO(new Bundle {
        val rIdx = Input(UInt(log2Up(size).W))
        val rdata = Output(rasEntry())
        val wen = Input(Bool())
        val wIdx = Input(UInt(log2Up(size).W))
        val wdata = Input(rasEntry())
        val copyen = Input(Bool())
        val copy_in = Input(Vec(size, rasEntry()))
        val copy_out = Output(Vec(size, rasEntry()))
      })
      val mem = Reg(Vec(size, rasEntry()))
      when (io.wen) {
        mem(io.wIdx) := io.wdata
      }
      io.rdata := mem(io.rIdx)
      (0 until size).foreach { i => io.copy_out(i) := mem(i) }
      when (io.copyen) {
        (0 until size).foreach { i => mem(i) := io.copy_in(i) }
      }
    }

    val sp = RegInit(0.U(log2Up(rasSize).W))
    val stack = Module(new Stack(rasSize)).io

    val sp_empty = sp === 0.U

    stack.rIdx := sp - 1.U
    val top_entry = stack.rdata
    val top_addr = top_entry.retAddr
    val top_ctr = top_entry.ctr

    // When the stack is empty, `sp - 1` wraps to a slot that is not part of the
    // stack, so its contents must never be matched against `new_addr`.
    // A saturated counter cannot absorb another call (`top_ctr + 1` would wrap
    // to 0), so a fresh entry is allocated instead.
    val ctr_saturated = top_ctr.andR
    val alloc_new = sp_empty || ctr_saturated || io.new_addr =/= top_addr
    // A pop that only decrements the top counter (the slot stays allocated).
    val pop_dec = io.pop_valid && top_ctr =/= 1.U

    stack.wen := io.push_valid || pop_dec
    stack.wIdx := Mux(pop_dec, sp - 1.U, Mux(alloc_new, sp, sp - 1.U))
    stack.wdata := Mux(pop_dec,
      RASEntry(top_addr, top_ctr - 1.U),
      Mux(alloc_new, RASEntry(io.new_addr, 1.U), RASEntry(top_addr, top_ctr + 1.U)))

    when (io.push_valid && alloc_new) {
      sp := sp + 1.U
    }

    when (io.pop_valid && top_ctr === 1.U) {
      sp := Mux(sp_empty, 0.U, sp - 1.U)
    }

    io.copy_out_mem := stack.copy_out
    io.copy_out_sp := sp
    stack.copyen := io.copy_valid
    stack.copy_in := io.copy_in_mem
    // Connected last so that a copy overrides any push/pop update of `sp`.
    when (io.copy_valid) {
      sp := io.copy_in_sp
    }

    io.top_addr := top_addr
    io.is_empty := sp_empty
    io.is_full := sp === (rasSize - 1).U
  }

  // Earlier register-based implementation, kept for reference only:
  // val ras_0 = Reg(Vec(RasSize, rasEntry())) //RegInit(0.U)asTypeOf(Vec(RasSize,rasEntry)) cause comb loop
  // val ras_1 = Reg(Vec(RasSize, rasEntry()))
  // val sp_0 = RegInit(0.U(log2Up(RasSize).W))
  // val sp_1 = RegInit(0.U(log2Up(RasSize).W))
  // val choose_bit = RegInit(false.B) //start with 0
  // val spec_ras = Mux(choose_bit, ras_1, ras_0)
  // val spec_sp = Mux(choose_bit,sp_1,sp_0)
  // val commit_ras = Mux(choose_bit, ras_0, ras_1)
  // val commit_sp = Mux(choose_bit,sp_0,sp_1)

  // val spec_ras = Reg(Vec(RasSize, rasEntry()))
  // val spec_sp = RegInit(0.U(log2Up(RasSize).W))
  // val commit_ras = Reg(Vec(RasSize, rasEntry()))
  // val commit_sp = RegInit(0.U(log2Up(RasSize).W))

  // ---------------- speculative stack (prediction time) ----------------
  val spec_ras = Module(new RASStack(RasSize)).io

  val spec_push = WireInit(false.B)
  val spec_pop = WireInit(false.B)
  // Return address = fetch pc + callIdx * 2 bytes + length of the call
  // instruction (2 for RVC, 4 otherwise).
  val spec_new_addr = WireInit(io.pc.bits + (io.callIdx.bits << 1.U) + Mux(io.isRVC, 2.U, 4.U))
  spec_ras.push_valid := spec_push
  spec_ras.pop_valid := spec_pop
  spec_ras.new_addr := spec_new_addr
  val spec_is_empty = spec_ras.is_empty
  val spec_is_full = spec_ras.is_full
  val spec_top_addr = spec_ras.top_addr

  spec_push := !spec_is_full && io.callIdx.valid && io.pc.valid
  spec_pop := !spec_is_empty && io.is_ret && io.pc.valid

  // ---------------- commit stack (updated from io.recover) ----------------
  val commit_ras = Module(new RASStack(RasSize)).io

  val commit_push = WireInit(false.B)
  val commit_pop = WireInit(false.B)
  val commit_new_addr = Mux(io.recover.bits.pd.isRVC, io.recover.bits.pc + 2.U, io.recover.bits.pc + 4.U)
  commit_ras.push_valid := commit_push
  commit_ras.pop_valid := commit_pop
  commit_ras.new_addr := commit_new_addr
  val commit_is_empty = commit_ras.is_empty
  val commit_is_full = commit_ras.is_full
  val commit_top_addr = commit_ras.top_addr

  commit_push := !commit_is_full && io.recover.valid && io.recover.bits.pd.isCall
  commit_pop := !commit_is_empty && io.recover.valid && io.recover.bits.pd.isRet

  io.out.valid := !spec_is_empty && io.is_ret
  io.out.bits.target := spec_top_addr
  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  // On a misprediction, the commit stack is copied into the speculative stack
  // one cycle later. The delayed copy therefore observes the commit stack after
  // it has been updated by the same recover request.
  val copy_valid = io.recover.valid && io.recover.bits.isMisPred
  val copy_next = RegNext(copy_valid, false.B)
  spec_ras.copy_valid := copy_next
  spec_ras.copy_in_mem := commit_ras.copy_out_mem
  spec_ras.copy_in_sp := commit_ras.copy_out_sp
  // The commit stack is never overwritten. Its copy-enable is a control input,
  // so tie it low explicitly instead of leaving it DontCare; the data inputs
  // are ignored while copy_valid is low.
  commit_ras.copy_valid := false.B
  commit_ras.copy_in_mem := DontCare
  commit_ras.copy_in_sp := DontCare

  //no need to pass the ras branchInfo
  io.branchInfo.rasSp := DontCare
  io.branchInfo.rasTopCtr := DontCare
  io.branchInfo.rasToqAddr := DontCare

  if (BPUDebug && debug) {
    // NOTE(review): the debug prints below are stale. They index the stacks as
    // Vecs and reference signals (spec_sp, commit_sp, spec_ras_write,
    // sepc_alloc_new, ...) that no longer exist, so they must be rewritten
    // against the RASStack ports (e.g. copy_out_mem / copy_out_sp) before
    // being re-enabled.
    // XSDebug("----------------RAS(spec)----------------\n")
    // XSDebug(" index addr ctr \n")
    // for(i <- 0 until RasSize){
    //   XSDebug(" (%d) 0x%x %d",i.U,spec_ras(i).retAddr,spec_ras(i).ctr)
    //   when(i.U === spec_sp){XSDebug(false,true.B," <----sp")}
    //   XSDebug(false,true.B,"\n")
    // }
    // XSDebug("----------------RAS(commit)----------------\n")
    // XSDebug(" index addr ctr \n")
    // for(i <- 0 until RasSize){
    //   XSDebug(" (%d) 0x%x %d",i.U,commit_ras(i).retAddr,commit_ras(i).ctr)
    //   when(i.U === commit_sp){XSDebug(false,true.B," <----sp")}
    //   XSDebug(false,true.B,"\n")
    // }

    // XSDebug(spec_push, "(spec_ras)push inAddr: 0x%x inCtr: %d | allocNewEntry:%d | sp:%d \n",spec_ras_write.retAddr,spec_ras_write.ctr,sepc_alloc_new,spec_sp.asUInt)
    // XSDebug(spec_pop, "(spec_ras)pop outValid:%d outAddr: 0x%x \n",io.out.valid,io.out.bits.target)
    // XSDebug(commit_push, "(commit_ras)push inAddr: 0x%x inCtr: %d | allocNewEntry:%d | sp:%d \n",commit_ras_write.retAddr,commit_ras_write.ctr,sepc_alloc_new,commit_sp.asUInt)
    // XSDebug(commit_pop, "(commit_ras)pop outValid:%d outAddr: 0x%x \n",io.out.valid,io.out.bits.target)
    // XSDebug("copyValid:%d copyNext:%d \n",copy_valid,copy_next)
  }

  // Legacy checkpoint-based recovery (superseded by the commit-stack copy above),
  // kept for reference only; it references signals that no longer exist.
  // val recoverSp = io.recover.bits.brInfo.rasSp
  // val recoverCtr = io.recover.bits.brInfo.rasTopCtr
  // val recoverAddr = io.recover.bits.brInfo.rasToqAddr
  // val recover_top = ras(recoverSp - 1.U)
  // when (recover_valid) {
  //   sp := recoverSp
  //   recover_top.ctr := recoverCtr
  //   recover_top.retAddr := recoverAddr
  //   XSDebug("RAS update: SP:%d , Ctr:%d \n",recoverSp,recoverCtr)
  // }
  // val recover_and_push = recover_valid && push
  // val recover_and_pop = recover_valid && pop
  // val recover_alloc_new = new_addr =/= recoverAddr
  // when(recover_and_push)
  // {
  //   when(recover_alloc_new){
  //     sp := recoverSp + 1.U
  //     ras(recoverSp).retAddr := new_addr
  //     ras(recoverSp).ctr := 1.U
  //     recover_top.retAddr := recoverAddr
  //     recover_top.ctr := recoverCtr
  //   } .otherwise{
  //     sp := recoverSp
  //     recover_top.ctr := recoverCtr + 1.U
  //     recover_top.retAddr := recoverAddr
  //   }
  // } .elsewhen(recover_and_pop)
  // {
  //   io.out.bits.target := recoverAddr
  //   when ( recover_top.ctr === 1.U) {
  //     sp := recoverSp - 1.U
  //   }.otherwise {
  //     sp := recoverSp
  //     recover_top.ctr := recoverCtr - 1.U
  //   }
  // }

}