// xref: /XiangShan/src/main/scala/xiangshan/frontend/RAS.scala (revision 3e52bed1735f5821bebb79dbb9b148cbb019938d)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
16
17package xiangshan.frontend
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.experimental.chiselName
22import chisel3.util._
23import utils._
24import xiangshan._
25
/** A single return-address-stack slot.
  *
  * `retAddr` holds the return address recorded for a call. `ctr` is the
  * nesting-layer counter: the number of pushes of this same return address
  * that have been merged into one slot instead of allocating new slots.
  */
class RASEntry()(implicit p: Parameters) extends XSBundle {
  val retAddr: UInt = UInt(VAddrBits.W)
  val ctr: UInt = UInt(8.W) // layer of nested call functions (8-bit counter)
}
30
/** Return-address stack (RAS) predictor.
  *
  * When the s2 prediction is a taken call, it pushes the s2 fall-through address.
  * When the s2 prediction is a taken return, it pops and overrides the last s2
  * target with the stack top. On a (registered) redirect, it restores the stack
  * from the sp/top snapshot carried in `cfiUpdate`.
  */
@chiselName
class RAS(implicit p: Parameters) extends BasePredictor {
  /** Builds a [[RASEntry]] wire from a return address and a counter value. */
  object RASEntry {
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val e = Wire(new RASEntry)
      e.retAddr := retAddr
      e.ctr := ctr
      e
    }
  }

  /** Circular return-address stack whose top entry is cached in a register.
    *
    * Pointers:
    *   - `sp` is the slot that the next newly allocated entry is written to.
    *   - `topPtr` is the slot that holds the current top.
    *
    * Push: if the pushed address equals the top's return address and the top
    * counter is not saturated (all ones), only the counter is incremented.
    * Otherwise a new slot is allocated.
    *
    * Pop: if the top counter is 1, both pointers move back one slot and the
    * top is reloaded from the stack. Otherwise the counter is just decremented.
    *
    * Recovery: while `recover_valid` is high, the `recover_*` inputs replace
    * the speculative ones. The state is restored from the snapshot, and the
    * requested push/pop is replayed on top of it.
    *
    * @param rasSize number of stack slots; all indices wrap modulo `rasSize`.
    */
  @chiselName
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val spec_new_addr = Input(UInt(VAddrBits.W))

      val recover_sp = Input(UInt(log2Up(rasSize).W))
      val recover_top = Input(new RASEntry)
      val recover_valid = Input(Bool())
      val recover_push = Input(Bool())
      val recover_pop = Input(Bool())
      val recover_new_addr = Input(UInt(VAddrBits.W))

      val sp = Output(UInt(log2Up(rasSize).W))
      val top = Output(new RASEntry)
    })

    val debugIO = IO(new Bundle{
      val push_entry = Output(new RASEntry)
      val alloc_new = Output(Bool())
      val sp = Output(UInt(log2Up(rasSize).W))
      val topRegister = Output(new RASEntry)
      // Sized by the constructor parameter so it always matches `stack`.
      val out_mem = Output(Vec(rasSize, new RASEntry))
    })

    // The storage, pointer widths and wrap arithmetic are all derived from
    // the same `rasSize`, so they can never disagree.
    val stack = Mem(rasSize, new RASEntry)
    val sp = RegInit(0.U(log2Up(rasSize).W))
    val top = RegInit(0.U.asTypeOf(new RASEntry))
    val topPtr = RegInit(0.U(log2Up(rasSize).W))

    /** Next slot index, wrapping from `rasSize - 1` back to 0. */
    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
    /** Previous slot index, wrapping from 0 to `rasSize - 1`. */
    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)

    // A push needs a fresh slot unless it repeats the top's return address
    // and the top's counter still has headroom (all ones = saturated).
    val alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR

    // TODO: fix overflow and underflow bugs
    /** Drives the next stack state for this cycle.
      *
      * @param recover      the operands come from a redirect snapshot. In that
      *                     case sp/topPtr/top are also restored from them.
      * @param do_push      push `do_new_addr`
      * @param do_pop       pop the top (ignored when `do_push` is set)
      * @param do_alloc_new the push must allocate a new slot instead of
      *                     incrementing the top counter
      * @param do_sp        stack pointer to operate from
      * @param do_top_ptr   top-slot pointer to operate from
      * @param do_new_addr  return address being pushed
      * @param do_top       top entry to operate from
      */
    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
                              do_top: RASEntry) = {
      when (do_push) {
        when (do_alloc_new) {
          // New slot at do_sp; it becomes the top and sp advances past it.
          sp     := ptrInc(do_sp)
          topPtr := do_sp
          top.retAddr := do_new_addr
          top.ctr := 1.U
          stack.write(do_sp, RASEntry(do_new_addr, 1.U))
        }.otherwise {
          // Same address as the top: merge into it by incrementing its counter.
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr + 1.U
          stack.write(do_top_ptr, RASEntry(do_new_addr, do_top.ctr + 1.U))
        }
      }.elsewhen (do_pop) {
        when (do_top.ctr === 1.U) {
          // Last merged layer: drop the slot and reload the entry below it.
          sp     := ptrDec(do_sp)
          topPtr := ptrDec(do_top_ptr)
          top := stack.read(ptrDec(do_top_ptr))
        }.otherwise {
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr - 1.U
          stack.write(do_top_ptr, RASEntry(do_top.retAddr, do_top.ctr - 1.U))
        }
      }.otherwise {
        // No push/pop: a recovery just restores the snapshot.
        when (recover) {
          sp := do_sp
          topPtr := do_top_ptr
          top := do_top
          stack.write(do_top_ptr, do_top)
        }
      }
    }

    // A redirect takes priority over this cycle's speculative push/pop.
    // The snapshot carries only sp, so the top pointer is the slot just
    // below it. ptrDec wraps that to rasSize - 1 correctly even when rasSize
    // is not a power of two, where `recover_sp - 1.U` would leave the range.
    update(io.recover_valid)(
      Mux(io.recover_valid, io.recover_push,       io.push_valid),
      Mux(io.recover_valid, io.recover_pop,        io.pop_valid),
      Mux(io.recover_valid, recover_alloc_new,     alloc_new),
      Mux(io.recover_valid, io.recover_sp,         sp),
      Mux(io.recover_valid, ptrDec(io.recover_sp), topPtr),
      Mux(io.recover_valid, io.recover_new_addr,   io.spec_new_addr),
      Mux(io.recover_valid, io.recover_top,        top))

    io.sp := sp
    io.top := top

    debugIO.push_entry := RASEntry(io.spec_new_addr, Mux(alloc_new, 1.U, top.ctr + 1.U))
    debugIO.alloc_new := alloc_new
    debugIO.sp := sp
    debugIO.topRegister := top
    for (i <- 0 until rasSize) {
      debugIO.out_mem(i) := stack.read(i.U)
    }
  }

  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io


  val spec_push = WireInit(false.B)
  val spec_pop = WireInit(false.B)
  // val jump_is_first = io.callIdx.bits === 0.U
  // val call_is_last_half = io.isLastHalfRVI && jump_is_first
  // val spec_new_addr = packetAligned(io.pc.bits) + (io.callIdx.bits << instOffsetBits.U) + Mux( (io.isRVC | call_is_last_half) && HasCExtension.B, 2.U, 4.U)
  // Address pushed by a speculative call: the s2 fall-through address.
  val spec_new_addr = io.in.bits.resp_in(0).s2.preds.fallThroughAddr
  spec_ras.push_valid := spec_push
  spec_ras.pop_valid  := spec_pop
  spec_ras.spec_new_addr := spec_new_addr
  val spec_top_addr = spec_ras.top.retAddr

  // confirm that the call/ret is the taken cfi
  spec_push := io.s2_fire && io.in.bits.resp_in(0).s2.hit_taken_on_call
  spec_pop  := io.s2_fire && io.in.bits.resp_in(0).s2.hit_taken_on_ret

  // A taken return is predicted to jump to the current stack top.
  when (spec_pop) {
    io.out.resp.s2.preds.targets.last := spec_top_addr
  }

  // Export the current sp/top with the s2 response. Presumably this snapshot
  // comes back as redirect.bits.cfiUpdate.rasSp/rasEntry below -- verify
  // against the pipeline.
  io.out.resp.s2.rasSp  := spec_ras.sp
  io.out.resp.s2.rasTop := spec_ras.top


  // Redirects are registered, so recovery happens one cycle after io.redirect.
  val redirect = RegNext(io.redirect)
  val do_recover = redirect.valid
  val recover_cfi = redirect.bits.cfiUpdate

  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
  // when we mispredict a call, we must redo a push operation
  // similarly, when we mispredict a return, we should redo a pop
  spec_ras.recover_valid := do_recover
  spec_ras.recover_push := callMissPred
  spec_ras.recover_pop  := retMissPred

  spec_ras.recover_sp  := recover_cfi.rasSp
  spec_ras.recover_top := recover_cfi.rasEntry
  // Address replayed on a push: pc plus the instruction size (2 for RVC, else 4).
  spec_ras.recover_new_addr := recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U)

  // TODO: back-up stack for ras
  // use checkpoint to recover RAS

  val spec_debug = spec.debugIO
  XSDebug("----------------RAS----------------\n")
  XSDebug(" TopRegister: 0x%x   %d \n",spec_debug.topRegister.retAddr,spec_debug.topRegister.ctr)
  XSDebug("  index       addr           ctr \n")
  for(i <- 0 until spec.rasSize){
    XSDebug("  (%d)   0x%x      %d",i.U,spec_debug.out_mem(i).retAddr,spec_debug.out_mem(i).ctr)
    when(i.U === spec_debug.sp){XSDebug(false,true.B,"   <----sp")}
    XSDebug(false,true.B,"\n")
  }
  XSDebug(spec_push, "(spec_ras)push  inAddr: 0x%x  inCtr: %d |  allocNewEntry:%d |   sp:%d \n",
    spec_new_addr,spec_debug.push_entry.ctr,spec_debug.alloc_new,spec_debug.sp.asUInt)
  XSDebug(spec_pop, "(spec_ras)pop  outAddr: 0x%x \n",io.out.resp.s2.target)
  XSDebug("recoverValid:%d recover(SP:%d retAddr:%x ctr:%d) \n",
    do_recover,recover_cfi.rasSp,recover_cfi.rasEntry.retAddr,recover_cfi.rasEntry.ctr)
}
206