xref: /XiangShan/src/main/scala/xiangshan/frontend/RAS.scala (revision 5df98e433cd3f6d01298881d7bad26b95a38989d)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.frontend
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.experimental.chiselName
22import chisel3.util._
23import utils._
24import xiangshan._
25
/** A single slot of the return address stack.
  *
  * Consecutive pushes of the same return address share one slot; `ctr`
  * records how many such pushes the slot currently stands for.
  */
class RASEntry()(implicit p: Parameters) extends XSBundle {
  /** Return address recorded for a call. */
  val retAddr: UInt = UInt(VAddrBits.W)
  /** Number of nested calls folded into this slot (8-bit counter). */
  val ctr: UInt = UInt(8.W)
}
29}
30
/** Return Address Stack predictor.
  *
  * At s2 it speculatively pushes the fall-through address of a taken call (or
  * pops on a taken return) and supplies the stack top as the jalr target of a
  * return. The s2 view (sp/top) is registered into s3. The speculative stack
  * is repaired when the s3 call/ret decision differs from the one acted on in
  * s2, or when a (registered) redirect arrives.
  */
@chiselName
class RAS(implicit p: Parameters) extends BasePredictor {
  /** Builds a [[RASEntry]] wire from its two fields. */
  object RASEntry {
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(new RASEntry)
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

  /** Circular return-address stack whose top entry is also held in a register.
    *
    * `sp` is the next free slot and `topPtr` the slot of the current top.
    * A push whose address equals the current top's address increments that
    * entry's `ctr` instead of allocating a new slot, unless `ctr` is saturated.
    * When `recover_valid` is high, the recover_* inputs replace the stored
    * sp/top and the requested recover push/pop is applied on top of them.
    *
    * @param rasSize number of stack slots; every pointer (and the storage
    *                itself) is sized and wrapped by this value.
    */
  @chiselName
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val spec_new_addr = Input(UInt(VAddrBits.W))

      val recover_sp = Input(UInt(log2Up(rasSize).W))
      val recover_top = Input(new RASEntry)
      val recover_valid = Input(Bool())
      val recover_push = Input(Bool())
      val recover_pop = Input(Bool())
      val recover_new_addr = Input(UInt(VAddrBits.W))

      val sp = Output(UInt(log2Up(rasSize).W))
      val top = Output(new RASEntry)
    })

    val debugIO = IO(new Bundle{
      val push_entry = Output(new RASEntry)
      val alloc_new = Output(Bool())
      val sp = Output(UInt(log2Up(rasSize).W))
      val topRegister = Output(new RASEntry)
      // One element per stack slot, so sized by this instance's rasSize.
      val out_mem = Output(Vec(rasSize, new RASEntry))
    })

    // Storage depth must match the pointer range used by ptrInc/ptrDec.
    val stack = Mem(rasSize, new RASEntry)
    val sp = RegInit(0.U(log2Up(rasSize).W))
    val top = RegInit(0.U.asTypeOf(new RASEntry))
    val topPtr = RegInit(0.U(log2Up(rasSize).W))

    /** Circular increment modulo rasSize. */
    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
    /** Circular decrement modulo rasSize. */
    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)

    // A new slot is needed when the pushed address differs from the top's,
    // or when the top's counter is already all ones and cannot grow.
    val alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR

    /** Applies one push or pop, or (when `recover`) restores the given state.
      *
      * @param recover      when true, sp/topPtr/top are reloaded from the do_*
      *                     arguments instead of only being modified in place
      * @param do_push      push `do_new_addr`
      * @param do_pop       pop the top entry; ignored when `do_push` is set
      * @param do_alloc_new on push, allocate a fresh slot instead of bumping ctr
      * @param do_sp        stack pointer the operation starts from
      * @param do_top_ptr   slot index of the top entry the operation starts from
      * @param do_new_addr  address being pushed
      * @param do_top       top entry the operation starts from
      */
    // TODO: fix overflow and underflow bugs
    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
                              do_top: RASEntry) = {
      when (do_push) {
        when (do_alloc_new) {
          sp     := ptrInc(do_sp)
          topPtr := do_sp
          top.retAddr := do_new_addr
          top.ctr := 1.U
          stack.write(do_sp, RASEntry(do_new_addr, 1.U))
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr + 1.U
          stack.write(do_top_ptr, RASEntry(do_new_addr, do_top.ctr + 1.U))
        }
      }.elsewhen (do_pop) {
        when (do_top.ctr === 1.U) {
          // Last folded call of this slot: expose the slot below as the new top.
          sp     := ptrDec(do_sp)
          topPtr := ptrDec(do_top_ptr)
          top := stack.read(ptrDec(do_top_ptr))
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr - 1.U
          stack.write(do_top_ptr, RASEntry(do_top.retAddr, do_top.ctr - 1.U))
        }
      }.otherwise {
        when (recover) {
          sp := do_sp
          topPtr := do_top_ptr
          top := do_top
          stack.write(do_top_ptr, do_top)
        }
      }
    }

    // The top slot is the one just below sp. Use ptrDec (not a raw `- 1.U`)
    // so a recovered sp of 0 yields rasSize-1 even when rasSize is not a
    // power of two; for power-of-two sizes the result is unchanged.
    update(io.recover_valid)(
      Mux(io.recover_valid, io.recover_push,        io.push_valid),
      Mux(io.recover_valid, io.recover_pop,         io.pop_valid),
      Mux(io.recover_valid, recover_alloc_new,      alloc_new),
      Mux(io.recover_valid, io.recover_sp,          sp),
      Mux(io.recover_valid, ptrDec(io.recover_sp),  topPtr),
      Mux(io.recover_valid, io.recover_new_addr,    io.spec_new_addr),
      Mux(io.recover_valid, io.recover_top,         top))

    io.sp := sp
    io.top := top

    // After reset, write a default entry (ctr = 0) into one slot per cycle
    // until every one of the rasSize slots has been written.
    val resetIdx = RegInit(0.U(log2Up(rasSize).W))
    val do_reset = RegInit(true.B)
    when (do_reset) {
      stack.write(resetIdx, RASEntry(0x80000000L.U, 0.U))
    }
    resetIdx := resetIdx + do_reset
    when (resetIdx === (rasSize-1).U) {
      do_reset := false.B
    }

    debugIO.push_entry := RASEntry(io.spec_new_addr, Mux(alloc_new, 1.U, top.ctr + 1.U))
    debugIO.alloc_new := alloc_new
    debugIO.sp := sp
    debugIO.topRegister := top
    for (i <- 0 until rasSize) {
      debugIO.out_mem(i) := stack.read(i.U)
    }
  }

  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io
  val spec_top_addr = spec_ras.top.retAddr

  // ---- s2: speculative push/pop; a return takes the stack top as target ----
  val s2_spec_push = WireInit(false.B)
  val s2_spec_pop = WireInit(false.B)
  // val jump_is_first = io.callIdx.bits === 0.U
  // val call_is_last_half = io.isLastHalfRVI && jump_is_first
  // val spec_new_addr = packetAligned(io.pc.bits) + (io.callIdx.bits << instOffsetBits.U) + Mux( (io.isRVC | call_is_last_half) && HasCExtension.B, 2.U, 4.U)
  // The address pushed for a call is the s2 prediction's fall-through address.
  val s2_spec_new_addr = io.in.bits.resp_in(0).s2.full_pred.fallThroughAddr
  spec_ras.push_valid := s2_spec_push
  spec_ras.pop_valid  := s2_spec_pop
  spec_ras.spec_new_addr := s2_spec_new_addr

  // confirm that the call/ret is the taken cfi
  s2_spec_push := io.s2_fire && io.in.bits.resp_in(0).s2.full_pred.hit_taken_on_call
  s2_spec_pop  := io.s2_fire && io.in.bits.resp_in(0).s2.full_pred.hit_taken_on_ret

  val s2_jalr_target = io.out.resp.s2.full_pred.jalr_target
  val s2_last_target_in = io.in.bits.resp_in(0).s2.full_pred.targets.last
  val s2_last_target_out = io.out.resp.s2.full_pred.targets.last
  val s2_is_jalr = io.in.bits.resp_in(0).s2.full_pred.is_jalr
  val s2_is_ret = io.in.bits.resp_in(0).s2.full_pred.is_ret
  // assert(is_jalr && is_ret || !is_ret)
  when(s2_is_ret) {
    s2_jalr_target := spec_top_addr
    // FIXME: should use s1 globally
  }
  s2_last_target_out := Mux(s2_is_jalr, s2_jalr_target, s2_last_target_in)

  // ---- s3: use the stack state captured when s2 fired ----
  val s3_top = RegEnable(spec_ras.top, io.s2_fire)
  val s3_sp = RegEnable(spec_ras.sp, io.s2_fire)
  val s3_spec_new_addr = RegEnable(s2_spec_new_addr, io.s2_fire)

  val s3_jalr_target = io.out.resp.s3.full_pred.jalr_target
  val s3_last_target_in = io.in.bits.resp_in(0).s3.full_pred.targets.last
  val s3_last_target_out = io.out.resp.s3.full_pred.targets.last
  val s3_is_jalr = io.in.bits.resp_in(0).s3.full_pred.is_jalr
  val s3_is_ret = io.in.bits.resp_in(0).s3.full_pred.is_ret
  // assert(is_jalr && is_ret || !is_ret)
  when(s3_is_ret) {
    s3_jalr_target := s3_top.retAddr
    // FIXME: should use s1 globally
  }
  s3_last_target_out := Mux(s3_is_jalr, s3_jalr_target, s3_last_target_in)

  // s3 must repair the stack if its call/ret decision differs from what s2 did.
  val s3_pushed_in_s2 = RegEnable(s2_spec_push, io.s2_fire)
  val s3_popped_in_s2 = RegEnable(s2_spec_pop,  io.s2_fire)
  val s3_push = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_call
  val s3_pop  = io.in.bits.resp_in(0).s3.full_pred.hit_taken_on_ret

  val s3_recover = io.s3_fire && (s3_pushed_in_s2 =/= s3_push || s3_popped_in_s2 =/= s3_pop)
  io.out.resp.s3.rasSp  := s3_sp
  io.out.resp.s3.rasTop := s3_top

  // ---- recovery: a registered redirect takes precedence over s3 repair ----
  val redirect = RegNext(io.redirect)
  val do_recover = redirect.valid || s3_recover
  val recover_cfi = redirect.bits.cfiUpdate

  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
  // when we mispredict a call, we must redo a push operation
  // similarly, when we mispredict a return, we should redo a pop
  spec_ras.recover_valid := do_recover
  spec_ras.recover_push := Mux(redirect.valid, callMissPred, s3_push)
  spec_ras.recover_pop  := Mux(redirect.valid, retMissPred, s3_pop)

  spec_ras.recover_sp  := Mux(redirect.valid, recover_cfi.rasSp, s3_sp)
  spec_ras.recover_top := Mux(redirect.valid, recover_cfi.rasEntry, s3_top)
  spec_ras.recover_new_addr := Mux(redirect.valid, recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U), s3_spec_new_addr)

  XSPerfAccumulate("ras_s3_recover", s3_recover)
  XSPerfAccumulate("ras_redirect_recover", redirect.valid)
  XSPerfAccumulate("ras_s3_and_redirect_recover_at_the_same_time", s3_recover && redirect.valid)
  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  val spec_debug = spec.debugIO
  XSDebug("----------------RAS----------------\n")
  XSDebug(" TopRegister: 0x%x   %d \n",spec_debug.topRegister.retAddr,spec_debug.topRegister.ctr)
  XSDebug("  index       addr           ctr \n")
  // Bounded by the instance's size so it always matches debugIO.out_mem.
  for(i <- 0 until spec.rasSize){
    XSDebug("  (%d)   0x%x      %d",i.U,spec_debug.out_mem(i).retAddr,spec_debug.out_mem(i).ctr)
    when(i.U === spec_debug.sp){XSDebug(false,true.B,"   <----sp")}
    XSDebug(false,true.B,"\n")
  }
  XSDebug(s2_spec_push, "(spec_ras)push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    s2_spec_new_addr,spec_debug.push_entry.ctr,spec_debug.alloc_new,spec_debug.sp.asUInt)
  XSDebug(s2_spec_pop, "(spec_ras)pop  outAddr: 0x%x \n",io.out.resp.s2.getTarget)
  XSDebug("recoverValid:%d recover(SP:%d retAddr:%x ctr:%d) \n",
    do_recover,recover_cfi.rasSp,recover_cfi.rasEntry.retAddr,recover_cfi.rasEntry.ctr)
}
250}
251