xref: /XiangShan/src/main/scala/xiangshan/frontend/RAS.scala (revision cf7d6b7a1a781c73aeb87de112de2e7fe5ea3b7c)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
/*
package xiangshan.frontend

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utils._
import utility._
import xiangshan._

import scala.{Tuple2 => &}
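// `&` aliases Tuple2 so that zipped duplicated signals can be destructured
// with infix patterns like `a & b & c` in the for comprehensions below.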


class RASEntry()(implicit p: Parameters) extends XSBundle {
    val retAddr = UInt(VAddrBits.W)
    val ctr = UInt(8.W) // recursion counter: extra pushes of the same return address folded into this entry
}
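// Example: three recursive calls from one call site push the same return
// address three times; instead of occupying three stack entries, the single
// top entry keeps that retAddr with ctr == 2 (two extra layers). Once ctr
// saturates (all ones), a fresh entry is allocated for the next push.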

class RAS(implicit p: Parameters) extends BasePredictor {
  object RASEntry {
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(new RASEntry)
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

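  // Circular hardware stack whose top entry is cached in a register.
  // Speculative push/pop requests come from the s2 prediction stage; the
  // recover_* inputs restore a checkpointed (sp, top) after an s3 disagreement
  // or a backend redirect, optionally replaying the corrected push/pop.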
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val spec_new_addr = Input(UInt(VAddrBits.W))

      val recover_sp = Input(UInt(log2Up(rasSize).W))
      val recover_top = Input(new RASEntry)
      val recover_valid = Input(Bool())
      val recover_push = Input(Bool())
      val recover_pop = Input(Bool())
      val recover_new_addr = Input(UInt(VAddrBits.W))

      val sp = Output(UInt(log2Up(rasSize).W))
      val top = Output(new RASEntry)
    })

    val debugIO = IO(new Bundle {
      val spec_push_entry = Output(new RASEntry)
      val spec_alloc_new = Output(Bool())
      val recover_push_entry = Output(new RASEntry)
      val recover_alloc_new = Output(Bool())
      val sp = Output(UInt(log2Up(rasSize).W))
      val topRegister = Output(new RASEntry)
      val out_mem = Output(Vec(rasSize, new RASEntry))
    })

    val stack = Mem(rasSize, new RASEntry)
    val sp = RegInit(0.U(log2Up(rasSize).W))
    val top = RegInit(0.U.asTypeOf(new RASEntry()))
    val topPtr = RegInit(0.U(log2Up(rasSize).W))

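    // Writes to the stack memory are delayed by one cycle: a pending write is
    // latched into the bypass registers below and committed to `stack` on the
    // following cycle. Reads that hit the pending pointer must therefore be
    // served from the bypass registers (see the read-bypass Mux in update()).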
    val wen = WireInit(false.B)
    val write_bypass_entry = RegInit(0.U.asTypeOf(new RASEntry()))
    val write_bypass_ptr = RegInit(0.U(log2Up(rasSize).W))
    val write_bypass_valid = RegInit(false.B)
    when (wen) {
      write_bypass_valid := true.B
    }.elsewhen (write_bypass_valid) {
      write_bypass_valid := false.B
    }

    when (write_bypass_valid) {
      stack(write_bypass_ptr) := write_bypass_entry
    }

    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)

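    // A fresh entry is allocated when the pushed address differs from the
    // current top, or when the top counter has saturated; otherwise the push
    // is a recursive call to the same address and only the counter grows.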
    val spec_alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR

    // TODO: fix overflow and underflow bugs
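    // Shared state machine for speculative updates and recoveries. When
    // `recover` is asserted, the do_* arguments carry the checkpointed state
    // from the redirect/s3 check instead of the live speculative state, so
    // sp/topPtr/top are restored as part of the same push/pop/idle operation.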
    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
                              do_top: RASEntry) = {
      when (do_push) {
        when (do_alloc_new) {
          sp     := ptrInc(do_sp)
          topPtr := do_sp
          top.retAddr := do_new_addr
          top.ctr := 0.U
          // write bypass
          wen := true.B
          write_bypass_entry := RASEntry(do_new_addr, 0.U)
          write_bypass_ptr := do_sp
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr + 1.U
          // write bypass
          wen := true.B
          write_bypass_entry := RASEntry(do_new_addr, do_top.ctr + 1.U)
          write_bypass_ptr := do_top_ptr
        }
      }.elsewhen (do_pop) {
        when (do_top.ctr === 0.U) {
          sp     := ptrDec(do_sp)
          topPtr := ptrDec(do_top_ptr)
          // read bypass
          top :=
            Mux(ptrDec(do_top_ptr) === write_bypass_ptr && write_bypass_valid,
              write_bypass_entry,
              stack.read(ptrDec(do_top_ptr))
            )
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr - 1.U
          // write bypass
          wen := true.B
          write_bypass_entry := RASEntry(do_top.retAddr, do_top.ctr - 1.U)
          write_bypass_ptr := do_top_ptr
        }
      }.otherwise {
        when (recover) {
          sp := do_sp
          topPtr := do_top_ptr
          top := do_top
          // write bypass
          wen := true.B
          write_bypass_entry := do_top
          write_bypass_ptr := do_top_ptr
        }
      }
    }


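    // On recovery, feed in the checkpointed state and the redo flags;
    // otherwise use the live speculative state and the s2 push/pop requests.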
    update(io.recover_valid)(
      Mux(io.recover_valid, io.recover_push,     io.push_valid),
      Mux(io.recover_valid, io.recover_pop,      io.pop_valid),
      Mux(io.recover_valid, recover_alloc_new,   spec_alloc_new),
      Mux(io.recover_valid, io.recover_sp,       sp),
      Mux(io.recover_valid, io.recover_sp - 1.U, topPtr),
      Mux(io.recover_valid, io.recover_new_addr, io.spec_new_addr),
      Mux(io.recover_valid, io.recover_top,      top))

    io.sp := sp
    io.top := top

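    // After reset, walk the stack once (one entry per cycle), initializing
    // every entry to address 0x80000000 with a zero counter.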
    val resetIdx = RegInit(0.U(log2Ceil(rasSize).W))
    val do_reset = RegInit(true.B)
    when (do_reset) {
      stack.write(resetIdx, RASEntry(0x80000000L.U, 0.U))
    }
    resetIdx := resetIdx + do_reset
    when (resetIdx === (rasSize-1).U) {
      do_reset := false.B
    }

    debugIO.spec_push_entry := RASEntry(io.spec_new_addr, Mux(spec_alloc_new, 1.U, top.ctr + 1.U))
    debugIO.spec_alloc_new := spec_alloc_new
    debugIO.recover_push_entry := RASEntry(io.recover_new_addr, Mux(recover_alloc_new, 1.U, io.recover_top.ctr + 1.U))
    debugIO.recover_alloc_new := recover_alloc_new
    debugIO.sp := sp
    debugIO.topRegister := top
    for (i <- 0 until rasSize) {
      debugIO.out_mem(i) := Mux(i.U === write_bypass_ptr && write_bypass_valid, write_bypass_entry, stack.read(i.U))
    }
  }

  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io
  val spec_top_addr = spec_ras.top.retAddr


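  // s2 performs the speculative stack update: push the return address on a
  // predicted taken call, pop on a predicted taken ret. s3 re-checks these
  // decisions below and recovers the stack when they turn out wrong.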
  val s2_spec_push = WireInit(false.B)
  val s2_spec_pop = WireInit(false.B)
  val s2_full_pred = io.in.bits.resp_in(0).s2.full_pred
  // If the last instruction is an RVI (4-byte) call, the fall-through address
  // points to its middle, so 2 must be added to obtain the return address.
  val s2_spec_new_addr = s2_full_pred(2).fallThroughAddr + Mux(s2_full_pred(2).last_may_be_rvi_call, 2.U, 0.U)
  spec_ras.push_valid := s2_spec_push
  spec_ras.pop_valid  := s2_spec_pop
  spec_ras.spec_new_addr := s2_spec_new_addr

  // confirm that the call/ret is the taken cfi
  s2_spec_push := io.s2_fire(2) && s2_full_pred(2).hit_taken_on_call && !io.s3_redirect(2)
  s2_spec_pop  := io.s2_fire(2) && s2_full_pred(2).hit_taken_on_ret  && !io.s3_redirect(2)

  val s2_jalr_target_dup = io.out.s2.full_pred.map(_.jalr_target)
  val s2_last_target_in_dup = s2_full_pred.map(_.targets.last)
  val s2_last_target_out_dup = io.out.s2.full_pred.map(_.targets.last)
  val s2_is_jalr_dup = s2_full_pred.map(_.is_jalr)
  val s2_is_ret_dup = s2_full_pred.map(_.is_ret)
  // assert(is_jalr && is_ret || !is_ret)
  val ras_enable_dup = dup(RegNext(io.ctrl.ras_enable))
  for (ras_enable & s2_is_ret & s2_jalr_target <-
    ras_enable_dup zip s2_is_ret_dup zip s2_jalr_target_dup) {
      when(s2_is_ret && ras_enable) {
        s2_jalr_target := spec_top_addr
        // FIXME: should use s1 globally
      }
    }
  for (s2_lto & s2_is_jalr & s2_jalr_target & s2_lti <-
    s2_last_target_out_dup zip s2_is_jalr_dup zip s2_jalr_target_dup zip s2_last_target_in_dup) {
      s2_lto := Mux(s2_is_jalr, s2_jalr_target, s2_lti)
    }

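  // Latch the s2-time stack state; s3 uses it both as its own prediction
  // source and as the rollback point if it disagrees with s2.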
  val s3_top_dup = io.s2_fire.map(f => RegEnable(spec_ras.top, f))
  val s3_sp = RegEnable(spec_ras.sp, io.s2_fire(2))
  val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire(2))

  val s3_full_pred = io.in.bits.resp_in(0).s3.full_pred
  val s3_jalr_target_dup = io.out.s3.full_pred.map(_.jalr_target)
  val s3_last_target_in_dup = s3_full_pred.map(_.targets.last)
  val s3_last_target_out_dup = io.out.s3.full_pred.map(_.targets.last)
  val s3_is_jalr_dup = s3_full_pred.map(_.is_jalr)
  val s3_is_ret_dup = s3_full_pred.map(_.is_ret)
  // assert(is_jalr && is_ret || !is_ret)

  for (ras_enable & s3_is_ret & s3_jalr_target & s3_top <-
    ras_enable_dup zip s3_is_ret_dup zip s3_jalr_target_dup zip s3_top_dup) {
      when(s3_is_ret && ras_enable) {
        s3_jalr_target := s3_top.retAddr
        // FIXME: should use s1 globally
      }
    }
  for (s3_lto & s3_is_jalr & s3_jalr_target & s3_lti <-
    s3_last_target_out_dup zip s3_is_jalr_dup zip s3_jalr_target_dup zip s3_last_target_in_dup) {
      s3_lto := Mux(s3_is_jalr, s3_jalr_target, s3_lti)
    }

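  // If s3's call/ret decision differs from what s2 speculatively performed,
  // trigger an in-pipeline recovery that replays the correct operation from
  // the latched s2 state.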
251
252  val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire(2))
253  val s3_popped_in_s2 = RegEnable(s2_spec_pop,  io.s2_fire(2))
254  val s3_push = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_call
255  val s3_pop  = io.in.bits.resp_in(0).s3.full_pred(2).hit_taken_on_ret
256
257  val s3_recover = io.s3_fire(2) && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop)
258  io.out.last_stage_spec_info.rasSp  := s3_sp
259  io.out.last_stage_spec_info.rasTop := s3_top_dup(2)
260
261
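  // A backend redirect outranks an s3 recovery: the stack is restored from
  // the (rasSp, rasEntry) checkpoint carried with the redirect, and the
  // mispredicted call/ret is replayed against that restored state.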
  val redirect = RegNext(io.redirect)
  val do_recover = redirect.valid || s3_recover
  val recover_cfi = redirect.bits.cfiUpdate

  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
  // when we mispredict a call, we must redo a push operation
  // similarly, when we mispredict a return, we should redo a pop
  spec_ras.recover_valid := do_recover
  spec_ras.recover_push := Mux(redirect.valid, callMissPred, s3_push)
  spec_ras.recover_pop  := Mux(redirect.valid, retMissPred, s3_pop)

  spec_ras.recover_sp  := Mux(redirect.valid, recover_cfi.rasSp, s3_sp)
  spec_ras.recover_top := Mux(redirect.valid, recover_cfi.rasEntry, s3_top_dup(2))
  spec_ras.recover_new_addr := Mux(redirect.valid, recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U), s3_spec_new_addr)


  XSPerfAccumulate("ras_s3_recover", s3_recover)
  XSPerfAccumulate("ras_redirect_recover", redirect.valid)
  XSPerfAccumulate("ras_s3_and_redirect_recover_at_the_same_time", s3_recover && redirect.valid)
  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  val spec_debug = spec.debugIO
  XSDebug("----------------RAS----------------\n")
  XSDebug(" TopRegister: 0x%x   %d \n", spec_debug.topRegister.retAddr, spec_debug.topRegister.ctr)
  XSDebug("  index       addr           ctr \n")
  for (i <- 0 until RasSize) {
    XSDebug("  (%d)   0x%x      %d", i.U, spec_debug.out_mem(i).retAddr, spec_debug.out_mem(i).ctr)
    when (i.U === spec_debug.sp) { XSDebug(false, true.B, "   <----sp") }
    XSDebug(false, true.B, "\n")
  }
  XSDebug(s2_spec_push, "s2_spec_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    s2_spec_new_addr, spec_debug.spec_push_entry.ctr, spec_debug.spec_alloc_new, spec_debug.sp.asUInt)
  XSDebug(s2_spec_pop, "s2_spec_pop  outAddr: 0x%x \n", io.out.s2.getTarget(2))
  val s3_recover_entry = spec_debug.recover_push_entry
  XSDebug(s3_recover && s3_push, "s3_recover_push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    s3_recover_entry.retAddr, s3_recover_entry.ctr, spec_debug.recover_alloc_new, s3_sp.asUInt)
  XSDebug(s3_recover && s3_pop, "s3_recover_pop  outAddr: 0x%x \n", io.out.s3.getTarget(2))
  val redirectUpdate = redirect.bits.cfiUpdate
  XSDebug(do_recover && callMissPred, "redirect_recover_push\n")
  XSDebug(do_recover && retMissPred, "redirect_recover_pop\n")
  XSDebug(do_recover, "redirect_recover(SP:%d retAddr:%x ctr:%d) \n",
    redirectUpdate.rasSp, redirectUpdate.rasEntry.retAddr, redirectUpdate.rasEntry.ctr)

  generatePerfEvent()
}
 */