// xref: /XiangShan/src/main/scala/xiangshan/frontend/RAS.scala (revision 57bb43b5f11c3f1e89ac52f232fe73056b35d9bd)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.frontend

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.experimental.chiselName
import chisel3.util._
import utils._
import xiangshan._

class RASEntry()(implicit p: Parameters) extends XSBundle {
  val retAddr = UInt(VAddrBits.W)
  val ctr = UInt(8.W) // counts nested calls that return to the same address (recursion depth)
}
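// Note on RASEntry.ctr: nested calls that return to the same address do not each
// occupy a new stack entry; the top entry's ctr is incremented instead (see
// RASStack.update below). For example, three consecutive recursive calls to the
// same site end up as a single entry with ctr = 2.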

@chiselName
class RAS(implicit p: Parameters) extends BasePredictor {
  object RASEntry {
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(new RASEntry)
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

  @chiselName
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val spec_new_addr = Input(UInt(VAddrBits.W))

      val recover_sp = Input(UInt(log2Up(rasSize).W))
      val recover_top = Input(new RASEntry)
      val recover_valid = Input(Bool())
      val recover_push = Input(Bool())
      val recover_pop = Input(Bool())
      val recover_new_addr = Input(UInt(VAddrBits.W))

      val sp = Output(UInt(log2Up(rasSize).W))
      val top = Output(new RASEntry)
    })
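    // push_valid/pop_valid/spec_new_addr carry the speculative update from the s2
    // prediction stage; the recover_* ports restore a checkpointed sp/top (on an s3
    // disagreement or an outer redirect) and optionally redo the push/pop implied by it.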

    val debugIO = IO(new Bundle {
      val spec_push_entry = Output(new RASEntry)
      val spec_alloc_new = Output(Bool())
      val recover_push_entry = Output(new RASEntry)
      val recover_alloc_new = Output(Bool())
      val sp = Output(UInt(log2Up(rasSize).W))
      val topRegister = Output(new RASEntry)
      val out_mem = Output(Vec(RasSize, new RASEntry))
    })

    val stack = Mem(RasSize, new RASEntry)
    val sp = RegInit(0.U(log2Up(rasSize).W))
    val top = RegInit(RASEntry(0x80000000L.U, 0.U))
    val topPtr = RegInit(0.U(log2Up(rasSize).W))

    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)
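    // ptrInc/ptrDec wrap around rasSize, so the stack behaves as a circular buffer:
    // pushing more than rasSize entries silently overwrites the oldest ones
    // (see the overflow/underflow TODO below).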

    val spec_alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR
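    // A push allocates a fresh entry only when the pushed return address differs from
    // the current top or the top counter is saturated (ctr.andR); otherwise the push is
    // folded into the top entry by bumping its counter.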

    // TODO: fix overflow and underflow bugs
    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
                              do_top: RASEntry) = {
      when (do_push) {
        when (do_alloc_new) {
          sp     := ptrInc(do_sp)
          topPtr := do_sp
          top.retAddr := do_new_addr
          top.ctr := 0.U
          stack.write(do_sp, RASEntry(do_new_addr, 0.U))
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr + 1.U
          stack.write(do_top_ptr, RASEntry(do_new_addr, do_top.ctr + 1.U))
        }
      }.elsewhen (do_pop) {
        when (do_top.ctr === 0.U) {
          sp     := ptrDec(do_sp)
          topPtr := ptrDec(do_top_ptr)
          top := stack.read(ptrDec(do_top_ptr))
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr - 1.U
          stack.write(do_top_ptr, RASEntry(do_top.retAddr, do_top.ctr - 1.U))
        }
      }.otherwise {
        when (recover) {
          sp := do_sp
          topPtr := do_top_ptr
          top := do_top
          stack.write(do_top_ptr, do_top)
        }
      }
    }

    update(io.recover_valid)(
      Mux(io.recover_valid, io.recover_push,     io.push_valid),
      Mux(io.recover_valid, io.recover_pop,      io.pop_valid),
      Mux(io.recover_valid, recover_alloc_new,   spec_alloc_new),
      Mux(io.recover_valid, io.recover_sp,       sp),
      Mux(io.recover_valid, io.recover_sp - 1.U, topPtr),
      Mux(io.recover_valid, io.recover_new_addr, io.spec_new_addr),
      Mux(io.recover_valid, io.recover_top,      top))
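    // Speculative updates and recoveries share this single update port: when
    // recover_valid is asserted, the checkpointed sp/top and the redo push/pop from the
    // recover_* inputs drive the operation; otherwise the s2 speculative push/pop does.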

    io.sp := sp
    io.top := top

    val resetIdx = RegInit(0.U(log2Ceil(RasSize).W))
    val do_reset = RegInit(true.B)
    when (do_reset) {
      stack.write(resetIdx, RASEntry(0x80000000L.U, 0.U))
    }
    resetIdx := resetIdx + do_reset
    when (resetIdx === (RasSize-1).U) {
      do_reset := false.B
    }
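    // After reset, the stack memory is initialized one entry per cycle: resetIdx sweeps
    // all RasSize entries, writing the same sentinel entry (0x80000000, ctr = 0) that the
    // top register resets to.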

    debugIO.spec_push_entry := RASEntry(io.spec_new_addr, Mux(spec_alloc_new, 1.U, top.ctr + 1.U))
    debugIO.spec_alloc_new := spec_alloc_new
    debugIO.recover_push_entry := RASEntry(io.recover_new_addr, Mux(recover_alloc_new, 1.U, io.recover_top.ctr + 1.U))
    debugIO.recover_alloc_new := recover_alloc_new
    debugIO.sp := sp
    debugIO.topRegister := top
    for (i <- 0 until RasSize) {
      debugIO.out_mem(i) := stack.read(i.U)
    }
  }

  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io
  val spec_top_addr = spec_ras.top.retAddr


  val s2_spec_push = WireInit(false.B)
  val s2_spec_pop = WireInit(false.B)
  val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred
  // when the last instruction is an RVI call, the fall-through address points to the middle of it, so 2 must be added to get the return address
  val s2_spec_new_addr = s2_full_pred.fallThroughAddr + Mux(s2_full_pred.last_may_be_rvi_call, 2.U, 0.U)
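  // Illustrative example (addresses made up): a 4-byte call whose first half occupies the
  // last 2 bytes of a prediction block ending at 0x1040 gives fallThroughAddr = 0x1040,
  // but the call itself ends at 0x1042, so the pushed return address is fallThroughAddr + 2.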
  spec_ras.push_valid := s2_spec_push
  spec_ras.pop_valid  := s2_spec_pop
  spec_ras.spec_new_addr := s2_spec_new_addr

  // confirm that the call/ret is the taken cfi
  s2_spec_push := io.s2_fire && s2_full_pred.hit_taken_on_call && !io.s3_redirect
  s2_spec_pop  := io.s2_fire && s2_full_pred.hit_taken_on_ret  && !io.s3_redirect

  val s2_jalr_target = io.out.resp.s2.full_pred.jalr_target
  val s2_last_target_in = s2_full_pred.targets.last
  val s2_last_target_out = io.out.resp.s2.full_pred.targets.last
  val s2_is_jalr = s2_full_pred.is_jalr
  val s2_is_ret = s2_full_pred.is_ret
  // assert(is_jalr && is_ret || !is_ret)
  when(s2_is_ret && io.ctrl.ras_enable) {
    s2_jalr_target := spec_top_addr
    // FIXME: should use s1 globally
  }
  s2_last_target_out := Mux(s2_is_jalr, s2_jalr_target, s2_last_target_in)
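  // For a predicted return in s2, the speculative RAS top overrides the jalr target from
  // the previous predictor stage; targets of non-indirect branches pass through unchanged.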

  val s3_top = RegEnable(spec_ras.top, io.s2_fire)
  val s3_sp = RegEnable(spec_ras.sp, io.s2_fire)
  val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire)

  val s3_jalr_target = io.out.resp.s3.full_pred.jalr_target
  val s3_last_target_in = io.in.bits.resp_in(0).s3.full_pred.targets.last
  val s3_last_target_out = io.out.resp.s3.full_pred.targets.last
  val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred.is_jalr
  val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred.is_ret
  // assert(is_jalr && is_ret || !is_ret)
  when(s3_is_ret && io.ctrl.ras_enable) {
    s3_jalr_target := s3_top.retAddr
    // FIXME: should use s1 globally
  }
  s3_last_target_out := Mux(s3_is_jalr, s3_jalr_target, s3_last_target_in)

  val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire)
  val s3_popped_in_s2 = RegEnable(s2_spec_pop,  io.s2_fire)
  val s3_push = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_call
  val s3_pop  = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_ret

  val s3_recover = io.s3_fire && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop)
  io.out.resp.s3.rasSp  := s3_sp
  io.out.resp.s3.rasTop := s3_top
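  // If s3 disagrees with the push/pop performed speculatively in s2, s3_recover rolls the
  // stack back to the s2-time snapshot (s3_sp/s3_top) and redoes the correct operation
  // through the recover port driven below; the snapshot is also exported with the s3 response.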

  val redirect = RegNext(io.redirect)
  val do_recover = redirect.valid || s3_recover
  val recover_cfi = redirect.bits.cfiUpdate

  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
  // when we mispredict a call, we must redo a push operation
  // similarly, when we mispredict a return, we should redo a pop
  spec_ras.recover_valid := do_recover
  spec_ras.recover_push := Mux(redirect.valid, callMissPred, s3_push)
  spec_ras.recover_pop  := Mux(redirect.valid, retMissPred, s3_pop)

  spec_ras.recover_sp  := Mux(redirect.valid, recover_cfi.rasSp, s3_sp)
  spec_ras.recover_top := Mux(redirect.valid, recover_cfi.rasEntry, s3_top)
  spec_ras.recover_new_addr := Mux(redirect.valid, recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U), s3_spec_new_addr)
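  // On an outer redirect, the sp/top checkpointed with the mispredicted instruction are
  // restored; if that instruction was a call, the address to re-push is the one following
  // it (pc + 2 for RVC, pc + 4 otherwise). An s3-initiated recovery reuses the s2-time
  // push address instead.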

  XSPerfAccumulate("ras_s3_recover", s3_recover)
  XSPerfAccumulate("ras_redirect_recover", redirect.valid)
  XSPerfAccumulate("ras_s3_and_redirect_recover_at_the_same_time", s3_recover && redirect.valid)
  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  val spec_debug = spec.debugIO
  XSDebug("----------------RAS----------------\n")
  XSDebug(" TopRegister: 0x%x   %d \n", spec_debug.topRegister.retAddr, spec_debug.topRegister.ctr)
  XSDebug("  index       addr           ctr \n")
  for (i <- 0 until RasSize) {
    XSDebug("  (%d)   0x%x      %d", i.U, spec_debug.out_mem(i).retAddr, spec_debug.out_mem(i).ctr)
    when (i.U === spec_debug.sp) { XSDebug(false, true.B, "   <----sp") }
    XSDebug(false, true.B, "\n")
  }
  XSDebug(s2_spec_push, "s2_spec_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    s2_spec_new_addr, spec_debug.spec_push_entry.ctr, spec_debug.spec_alloc_new, spec_debug.sp.asUInt)
  XSDebug(s2_spec_pop, "s2_spec_pop  outAddr: 0x%x \n", io.out.resp.s2.getTarget)
  val s3_recover_entry = spec_debug.recover_push_entry
  XSDebug(s3_recover && s3_push, "s3_recover_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    s3_recover_entry.retAddr, s3_recover_entry.ctr, spec_debug.recover_alloc_new, s3_sp.asUInt)
  XSDebug(s3_recover && s3_pop, "s3_recover_pop  outAddr: 0x%x \n", io.out.resp.s3.getTarget)
  val redirectUpdate = redirect.bits.cfiUpdate
  XSDebug(do_recover && callMissPred, "redirect_recover_push\n")
  XSDebug(do_recover && retMissPred, "redirect_recover_pop\n")
  XSDebug(do_recover, "redirect_recover(SP:%d retAddr:%x ctr:%d) \n",
    redirectUpdate.rasSp, redirectUpdate.rasEntry.retAddr, redirectUpdate.rasEntry.ctr)

  generatePerfEvent()
}