xref: /XiangShan/src/main/scala/xiangshan/backend/fu/CSR.scala (revision 5c5bd416ce761d956348a8e2fbbf268922371d8b)
1package xiangshan.backend.fu
2
3import chisel3._
4import chisel3.ExcitingUtils.{ConnectionType, Debug}
5import chisel3.util._
6import utils._
7import xiangshan._
8import xiangshan.backend._
9import xiangshan.frontend.BPUCtrl
10import xiangshan.backend.fu.util._
11
/** Synchronous exception cause numbers (mcause/scause encodings) plus helpers that
  * describe which pipeline units may report which exceptions.
  */
trait HasExceptionNO {
  def instrAddrMisaligned = 0
  def instrAccessFault    = 1
  def illegalInstr        = 2
  def breakPoint          = 3
  def loadAddrMisaligned  = 4
  def loadAccessFault     = 5
  def storeAddrMisaligned = 6
  def storeAccessFault    = 7
  def ecallU              = 8
  def ecallS              = 9
  def ecallM              = 11
  def instrPageFault      = 12
  def loadPageFault       = 13
  def storePageFault      = 15

  /** Exception numbers from highest to lowest priority. When several exception
    * bits are set, the one appearing first in this list is reported as the cause.
    */
  val ExcPriority = Seq(
    breakPoint, // TODO: different BP has different priority
    instrPageFault,
    instrAccessFault,
    illegalInstr,
    instrAddrMisaligned,
    ecallM, ecallS, ecallU,
    storePageFault,
    loadPageFault,
    storeAccessFault,
    loadAccessFault,
    storeAddrMisaligned,
    loadAddrMisaligned
  )
  /** Exceptions that the frontend may attach to an instruction. */
  val frontendSet = List(
    // instrAddrMisaligned,
    instrAccessFault,
    illegalInstr,
    instrPageFault
  )
  /** Exceptions the CSR unit may raise. */
  val csrSet = List(
    illegalInstr,
    breakPoint,
    ecallU,
    ecallS,
    ecallM
  )
  /** Exceptions a load unit may raise. */
  val loadUnitSet = List(
    loadAddrMisaligned,
    loadAccessFault,
    loadPageFault
  )
  /** Exceptions a store unit may raise. */
  val storeUnitSet = List(
    storeAddrMisaligned,
    storeAccessFault,
    storePageFault
  )
  /** Atomics may raise any load or store exception. */
  val atomicsUnitSet = (loadUnitSet ++ storeUnitSet).distinct
  val allPossibleSet = (frontendSet ++ csrSet ++ loadUnitSet ++ storeUnitSet).distinct

  // Per cause number (0 until 16): 1 if the given unit can report it, else 0.
  val csrWbCount = (0 until 16).map(i => if (csrSet.contains(i)) 1 else 0)
  val loadWbCount = (0 until 16).map(i => if (loadUnitSet.contains(i)) 1 else 0)
  val storeWbCount = (0 until 16).map(i => if (storeUnitSet.contains(i)) 1 else 0)
  val atomicsWbCount = (0 until 16).map(i => if (atomicsUnitSet.contains(i)) 1 else 0)
  // Total writeback sources per cause number; store exceptions are counted twice
  // (NOTE(review): presumably two store writeback ports — confirm against the backend).
  val writebackCount = (0 until 16).map(i => csrWbCount(i) + atomicsWbCount(i) + loadWbCount(i) + 2 * storeWbCount(i))

  /** Extract the exception bits listed in `select` from `vec`.
    *
    * @param vec          source exception vector, indexed by cause number
    * @param select       cause numbers to keep
    * @param dontCareBits if true, return an `ExceptionVec()` wire whose unselected bits are DontCare
    * @param falseBits    if true (and `dontCareBits` is false), return an `ExceptionVec()` wire whose
    *                     unselected bits are tied to false
    * @return the selected bits in place (first two modes) or, when both flags are false,
    *         a compacted `Vec(select.length, Bool())` holding the selected bits in `select` order
    */
  def partialSelect(vec: Vec[Bool], select: Seq[Int], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] = {
    if (dontCareBits) {
      val newVec = Wire(ExceptionVec())
      newVec := DontCare
      select.foreach(i => newVec(i) := vec(i))
      newVec
    } else if (falseBits) {
      val newVec = Wire(ExceptionVec())
      newVec.foreach(_ := false.B)
      select.foreach(i => newVec(i) := vec(i))
      newVec
    } else {
      val newVec = Wire(Vec(select.length, Bool()))
      select.zipWithIndex.foreach { case (s, i) => newVec(i) := vec(s) }
      newVec
    }
  }
  def selectFrontend(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, frontendSet, dontCareBits, falseBits)
  def selectCSR(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, csrSet, dontCareBits, falseBits)
  def selectLoad(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, loadUnitSet, dontCareBits, falseBits)
  def selectStore(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, storeUnitSet, dontCareBits, falseBits)
  def selectAtomics(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, atomicsUnitSet, dontCareBits, falseBits)
  def selectAll(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, allPossibleSet, dontCareBits, falseBits)
}
104
/** FPU <-> CSR interface. Directions are from the FPU side; the CSR file
  * instantiates it `Flipped`.
  */
class FpuCsrIO extends XSBundle {
  /** Accrued floating-point exception flags; OR-ed into fcsr.fflags when valid. */
  val fflags = Output(Valid(UInt(5.W)))
  val isIllegal = Output(Bool())
  /** Request to mark mstatus.FS (and SD) dirty. */
  val dirty_fs = Output(Bool())
  /** Current dynamic rounding mode (fcsr.frm). */
  val frm = Input(UInt(3.W))
}
111
112
/** Performance information delivered to the CSR file. */
class PerfCounterIO extends XSBundle {
  /** Instructions retired this cycle; accumulated into minstret. */
  val retiredInstr = Input(UInt(3.W))
  val value = Input(UInt(XLEN.W))
}
117
/** Non-datapath ports of the CSR functional unit. */
class CSRFileIO extends XSBundle {
  /** Reset value of mhartid. */
  val hartId = Input(UInt(64.W))
  // performance counter inputs
  val perf = new PerfCounterIO
  // asserted when the accessed CSR address lies in the performance-counter range
  val isPerfCnt = Output(Bool())
  // to/from FPU
  val fpu = Flipped(new FpuCsrIO)
  // exception / interrupt being taken, from ROB
  val exception = Flipped(ValidIO(new ExceptionInfo))
  // to ROB: redirect information for xRET and traps
  val isXRet = Output(Bool())
  val trapTarget = Output(UInt(VAddrBits.W))
  val interrupt = Output(Bool())
  // from LSQ: faulting virtual address for memory exceptions
  val memExceptionVAddr = Input(UInt(VAddrBits.W))
  // from outside cpu, externalInterrupt
  val externalInterrupt = new ExternalInterruptIO
  // TLB
  val tlb = Output(new TlbCsrBundle)
  // Custom microarchitecture ctrl signal
  val customCtrl = Output(new CustomCSRCtrlIO)
}
140
/** Control and Status Register (CSR) functional unit.
  *
  * Executes CSR read/modify/write instructions and the privileged "jmp"-type
  * operations (ecall, ebreak, mret, sret, uret), and updates the trap CSRs
  * (xepc, xcause, xtval, mstatus, privilege mode) when the ROB reports an
  * exception or interrupt through `csrio.exception`.
  */
class CSR extends FunctionUnit with HasCSRConst
{
  val csrio = IO(new CSRFileIO)
  // Architectural state exported for difftest; only driven when !env.FPGAPlatform (end of class).
  val difftestIO = IO(new Bundle() {
    val intrNO = Output(UInt(64.W))
    val cause = Output(UInt(64.W))
    val priviledgeMode = Output(UInt(2.W))
    val mstatus = Output(UInt(64.W))
    val sstatus = Output(UInt(64.W))
    val mepc = Output(UInt(64.W))
    val sepc = Output(UInt(64.W))
    val mtval = Output(UInt(64.W))
    val stval = Output(UInt(64.W))
    val mtvec = Output(UInt(64.W))
    val stvec = Output(UInt(64.W))
    val mcause = Output(UInt(64.W))
    val scause = Output(UInt(64.W))
    val satp = Output(UInt(64.W))
    val mip = Output(UInt(64.W))
    val mie = Output(UInt(64.W))
    val mscratch = Output(UInt(64.W))
    val sscratch = Output(UInt(64.W))
    val mideleg = Output(UInt(64.W))
    val medeleg = Output(UInt(64.W))
  })
  difftestIO <> DontCare

  val cfIn = io.in.bits.uop.cf
  val cfOut = Wire(new CtrlFlow)
  cfOut := cfIn
  val flushPipe = Wire(Bool())

  // src1: source operand; src2: immediate, whose bits 11:0 carry the CSR address
  // and bits 16:12 the zimm operand of the *i variants (see `addr` / `csri`).
  val (valid, src1, src2, func) = (
    io.in.valid,
    io.in.bits.src(0),
    io.in.bits.uop.ctrl.imm,
    io.in.bits.uop.ctrl.fuOpType
  )

  // CSR define

  class Priv extends Bundle {
    val m = Output(Bool())
    val h = Output(Bool())
    val s = Output(Bool())
    val u = Output(Bool())
  }

  val csrNotImplemented = RegInit(UInt(XLEN.W), 0.U)

  class MstatusStruct extends Bundle {
    val sd = Output(UInt(1.W))

    val pad1 = if (XLEN == 64) Output(UInt(27.W)) else null
    val sxl  = if (XLEN == 64) Output(UInt(2.W))  else null
    val uxl  = if (XLEN == 64) Output(UInt(2.W))  else null
    val pad0 = if (XLEN == 64) Output(UInt(9.W))  else Output(UInt(8.W))

    val tsr = Output(UInt(1.W))
    val tw = Output(UInt(1.W))
    val tvm = Output(UInt(1.W))
    val mxr = Output(UInt(1.W))
    val sum = Output(UInt(1.W))
    val mprv = Output(UInt(1.W))
    val xs = Output(UInt(2.W))
    val fs = Output(UInt(2.W))
    val mpp = Output(UInt(2.W))
    val hpp = Output(UInt(2.W))
    val spp = Output(UInt(1.W))
    val pie = new Priv
    val ie = new Priv
    assert(this.getWidth == XLEN)
  }

  class SatpStruct extends Bundle {
    val mode = UInt(4.W)
    val asid = UInt(16.W)
    val ppn  = UInt(44.W)
  }

  class Interrupt extends Bundle {
    val e = new Priv
    val t = new Priv
    val s = new Priv
  }

  // Machine-Level CSRs

  val mtvec = RegInit(UInt(XLEN.W), 0.U)
  val mcounteren = RegInit(UInt(XLEN.W), 0.U)
  val mcause = RegInit(UInt(XLEN.W), 0.U)
  val mtval = RegInit(UInt(XLEN.W), 0.U)
  val mepc = Reg(UInt(XLEN.W))

  val mie = RegInit(0.U(XLEN.W))
  // mip = bits driven by external sources (mipWire) OR software-written bits (mipReg)
  val mipWire = WireInit(0.U.asTypeOf(new Interrupt))
  val mipReg  = RegInit(0.U.asTypeOf(new Interrupt).asUInt)
  val mipFixMask = GenMask(9) | GenMask(5) | GenMask(1) // bits of mipReg writable through mip
  val mip = (mipWire.asUInt | mipReg).asTypeOf(new Interrupt)

  def getMisaMxl(mxl: Int): UInt = {mxl.U << (XLEN-2)}.asUInt()
  def getMisaExt(ext: Char): UInt = {1.U << (ext.toInt - 'a'.toInt)}.asUInt()
  var extList = List('a', 's', 'i', 'u')
  if (HasMExtension) { extList = extList :+ 'm' }
  if (HasCExtension) { extList = extList :+ 'c' }
  if (HasFPU) { extList = extList ++ List('f', 'd') }
  val misaInitVal = getMisaMxl(2) | extList.foldLeft(0.U)((sum, i) => sum | getMisaExt(i)) //"h8000000000141105".U
  val misa = RegInit(UInt(XLEN.W), misaInitVal)

  // MXL = 2          | 0 | EXT = b 00 0000 0100 0001 0001 0000 0101
  // (XLEN-1, XLEN-2) |   |(25, 0)  ZY XWVU TSRQ PONM LKJI HGFE DCBA

  val mvendorid = RegInit(UInt(XLEN.W), 0.U) // this is a non-commercial implementation
  val marchid = RegInit(UInt(XLEN.W), 0.U) // return 0 to indicate the field is not implemented
  val mimpid = RegInit(UInt(XLEN.W), 0.U) // provides a unique encoding of the version of the processor implementation
  val mhartid = RegInit(UInt(XLEN.W), csrio.hartId) // the hardware thread running the code
  val mstatus = RegInit(UInt(XLEN.W), 0.U)

  // mstatus Value Table
  // | sd   |
  // | pad1 |
  // | sxl  | hardlinked to 10, use 00 to pass xv6 test
  // | uxl  | hardlinked to 00
  // | pad0 |
  // | tsr  |
  // | tw   |
  // | tvm  |
  // | mxr  |
  // | sum  |
  // | mprv |
  // | xs   | 00 |
  // | fs   | 00 |
  // | mpp  | 00 |
  // | hpp  | 00 |
  // | spp  | 0 |
  // | pie  | 0000 | pie.h is used as UBE
  // | ie   | 0000 | uie hardlinked to 0, as N ext is not implemented

  val mstatusStruct = mstatus.asTypeOf(new MstatusStruct)
  // Recompute SD (bit XLEN-1) from XS/FS on every software write of mstatus/sstatus.
  def mstatusUpdateSideEffect(mstatus: UInt): UInt = {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = Cat(mstatusOld.xs === "b11".U || mstatusOld.fs === "b11".U, mstatus(XLEN-2, 0))
    mstatusNew
  }

  // mstatus read/write mask: everything except the listed reserved/unsupported bits
  val mstatusMask = (~ZeroExt((
    GenMask(XLEN-2, 38) | GenMask(31, 23) | GenMask(10, 9) | GenMask(2) |
    GenMask(37) | // MBE
    GenMask(36) | // SBE
    GenMask(6)    // UBE
  ), 64)).asUInt()

  val medeleg = RegInit(UInt(XLEN.W), 0.U)
  val mideleg = RegInit(UInt(XLEN.W), 0.U)
  val mscratch = RegInit(UInt(XLEN.W), 0.U)

  val pmpcfg0 = RegInit(UInt(XLEN.W), 0.U)
  val pmpcfg1 = RegInit(UInt(XLEN.W), 0.U)
  val pmpcfg2 = RegInit(UInt(XLEN.W), 0.U)
  val pmpcfg3 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr0 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr1 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr2 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr3 = RegInit(UInt(XLEN.W), 0.U)

  // Supervisor-Level CSRs

  // val sstatus = RegInit(UInt(XLEN.W), "h00000000".U)
  // Sstatus Write Mask (0xc6122)
  // -------------------------------------------------------
  //  bit:  19   18   14:13   8    5     1
  //        MXR  SUM  FS      SPP  SPIE  SIE
  // -------------------------------------------------------
  val sstatusWmask = "hc6122".U
  // Sstatus Read Mask = SSTATUS_WMASK | SD (bit 63) | UXL (bits 33:32) | XS (bits 16:15)
  val sstatusRmask = sstatusWmask | "h8000000300018000".U
  val stvec = RegInit(UInt(XLEN.W), 0.U)
  // val sie = RegInit(0.U(XLEN.W))
  // sie/sip are views of mie/mip restricted to the delegated S-level bits
  val sieMask = "h222".U & mideleg
  val sipMask  = "h222".U & mideleg
  val satp = if(EnbaleTlbDebug) RegInit(UInt(XLEN.W), "h8000000000087fbe".U) else RegInit(0.U(XLEN.W))
  // val satp = RegInit(UInt(XLEN.W), "h8000000000087fbe".U) // only use for tlb naive debug
  val satpMask = "h80000fffffffffff".U // disable asid, mode can only be 8 / 0
  val sepc = RegInit(UInt(XLEN.W), 0.U)
  val scause = RegInit(UInt(XLEN.W), 0.U)
  val stval = Reg(UInt(XLEN.W))
  val sscratch = RegInit(UInt(XLEN.W), 0.U)
  val scounteren = RegInit(UInt(XLEN.W), 0.U)

  // sbpctl
  // Bits 0-6: {LOOP, RAS, SC, TAGE, BIM, BTB, uBTB}
  val sbpctl = RegInit(UInt(XLEN.W), "h7f".U)
  csrio.customCtrl.bp_ctrl.ubtb_enable := sbpctl(0)
  csrio.customCtrl.bp_ctrl.btb_enable  := sbpctl(1)
  csrio.customCtrl.bp_ctrl.bim_enable  := sbpctl(2)
  csrio.customCtrl.bp_ctrl.tage_enable := sbpctl(3)
  csrio.customCtrl.bp_ctrl.sc_enable   := sbpctl(4)
  csrio.customCtrl.bp_ctrl.ras_enable  := sbpctl(5)
  csrio.customCtrl.bp_ctrl.loop_enable := sbpctl(6)

  // spfctl Bit 0: L1plusCache Prefetcher Enable
  // spfctl Bit 1: L2Cache Prefetcher Enable
  val spfctl = RegInit(UInt(XLEN.W), "h3".U)
  csrio.customCtrl.l1plus_pf_enable := spfctl(0)
  csrio.customCtrl.l2_pf_enable := spfctl(1)

  // sdsid: Differentiated Services ID
  val sdsid = RegInit(UInt(XLEN.W), 0.U)
  csrio.customCtrl.dsid := sdsid

  // slvpredctl: load violation predict settings
  val slvpredctl = RegInit(UInt(XLEN.W), "h70".U) // default reset period: 2^17
  csrio.customCtrl.lvpred_disable := slvpredctl(0)
  csrio.customCtrl.no_spec_load := slvpredctl(1)
  csrio.customCtrl.waittable_timeout := slvpredctl(8, 4)

  // smblockctl: memory block configurations
  // bits 0-3: store buffer flush threshold (default: 8 entries)
  val smblockctl = RegInit(UInt(XLEN.W), "h7".U)
  csrio.customCtrl.sbuffer_threshold := smblockctl(3, 0)

  // srnctl bit 0: move elimination enable
  val srnctl = RegInit(UInt(XLEN.W), "h1".U)
  csrio.customCtrl.move_elim_enable := srnctl(0)

  val tlbBundle = Wire(new TlbCsrBundle)
  tlbBundle.satp := satp.asTypeOf(new SatpStruct)
  csrio.tlb := tlbBundle

  // User-Level CSRs
  val uepc = Reg(UInt(XLEN.W))

  // fcsr
  class FcsrStruct extends Bundle {
    val reserved = UInt((XLEN-3-5).W)
    val frm = UInt(3.W)
    val fflags = UInt(5.W)
    assert(this.getWidth == XLEN)
  }
  val fcsr = RegInit(0.U(XLEN.W))
  // set mstatus->sd and mstatus->fs when true
  val csrw_dirty_fp_state = WireInit(false.B)

  // Write-side function for the frm view of fcsr; also marks FP state dirty.
  def frm_wfn(wdata: UInt): UInt = {
    val fcsrOld = WireInit(fcsr.asTypeOf(new FcsrStruct))
    csrw_dirty_fp_state := true.B
    fcsrOld.frm := wdata(2,0)
    fcsrOld.asUInt()
  }
  def frm_rfn(rdata: UInt): UInt = rdata(7,5)

  // Write-side function for fflags; `update = true` ORs the new flags into the
  // accrued ones (FPU path), `update = false` overwrites them (CSR write path).
  def fflags_wfn(update: Boolean)(wdata: UInt): UInt = {
    val fcsrOld = fcsr.asTypeOf(new FcsrStruct)
    val fcsrNew = WireInit(fcsrOld)
    csrw_dirty_fp_state := true.B
    if (update) {
      fcsrNew.fflags := wdata(4,0) | fcsrOld.fflags
    } else {
      fcsrNew.fflags := wdata(4,0)
    }
    fcsrNew.asUInt()
  }
  def fflags_rfn(rdata:UInt): UInt = rdata(4,0)

  // Write-side function for the whole fcsr; reserved bits are preserved.
  def fcsr_wfn(wdata: UInt): UInt = {
    val fcsrOld = WireInit(fcsr.asTypeOf(new FcsrStruct))
    csrw_dirty_fp_state := true.B
    Cat(fcsrOld.reserved, wdata.asTypeOf(fcsrOld).frm, wdata.asTypeOf(fcsrOld).fflags)
  }

  val fcsrMapping = Map(
    MaskedRegMap(Fflags, fcsr, wfn = fflags_wfn(update = false), rfn = fflags_rfn),
    MaskedRegMap(Frm, fcsr, wfn = frm_wfn, rfn = frm_rfn),
    MaskedRegMap(Fcsr, fcsr, wfn = fcsr_wfn)
  )

  // Atom LR/SC Control Bits
  //  val setLr = WireInit(Bool(), false.B)
  //  val setLrVal = WireInit(Bool(), false.B)
  //  val setLrAddr = WireInit(UInt(AddrBits.W), DontCare) //TODO : need check
  //  val lr = RegInit(Bool(), false.B)
  //  val lrAddr = RegInit(UInt(AddrBits.W), 0.U)
  //
  //  when (setLr) {
  //    lr := setLrVal
  //    lrAddr := setLrAddr
  //  }

  // Hart Privilege Mode
  val priviledgeMode = RegInit(UInt(2.W), ModeM)

  // Emu perfcnt
  val hasEmuPerfCnt = !env.FPGAPlatform
  val nrEmuPerfCnts = if (hasEmuPerfCnt) 0x80 else 0x3

  val emuPerfCnts    = List.fill(nrEmuPerfCnts)(RegInit(0.U(XLEN.W)))
  val emuPerfCntCond = List.fill(nrEmuPerfCnts)(WireInit(false.B))
  (emuPerfCnts zip emuPerfCntCond).foreach { case (c, e) => when (e) { c := c + 1.U } }

  val emuPerfCntsLoMapping = (0 until nrEmuPerfCnts).map(i => MaskedRegMap(0x1000 + i, emuPerfCnts(i)))
  val emuPerfCntsHiMapping = (0 until nrEmuPerfCnts).map(i => MaskedRegMap(0x1080 + i, emuPerfCnts(i)(63, 32)))
  println(s"CSR: hasEmuPerfCnt:${hasEmuPerfCnt}")

  // Perf Counter
  val nrPerfCnts = 29  // 3...31
  val perfCnts   = List.fill(nrPerfCnts)(RegInit(0.U(XLEN.W)))
  val perfEvents = List.fill(nrPerfCnts)(RegInit(0.U(XLEN.W)))
  val mcountinhibit = RegInit(0.U(XLEN.W))
  val mcycle = RegInit(0.U(XLEN.W))
  mcycle := mcycle + 1.U
  val minstret = RegInit(0.U(XLEN.W))
  minstret := minstret + RegNext(csrio.perf.retiredInstr)

  // CSR reg map
  val basicPrivMapping = Map(

    //--- User Trap Setup ---
    // MaskedRegMap(Ustatus, ustatus),
    // MaskedRegMap(Uie, uie, 0.U, MaskedRegMap.Unwritable),
    // MaskedRegMap(Utvec, utvec),

    //--- User Trap Handling ---
    // MaskedRegMap(Uscratch, uscratch),
    // MaskedRegMap(Uepc, uepc),
    // MaskedRegMap(Ucause, ucause),
    // MaskedRegMap(Utval, utval),
    // MaskedRegMap(Uip, uip),

    //--- User Counter/Timers ---
    // MaskedRegMap(Cycle, cycle),
    // MaskedRegMap(Time, time),
    // MaskedRegMap(Instret, instret),

    //--- Supervisor Trap Setup ---
    MaskedRegMap(Sstatus, mstatus, sstatusWmask, mstatusUpdateSideEffect, sstatusRmask),
    // MaskedRegMap(Sedeleg, Sedeleg),
    // MaskedRegMap(Sideleg, Sideleg),
    MaskedRegMap(Sie, mie, sieMask, MaskedRegMap.NoSideEffect, sieMask),
    MaskedRegMap(Stvec, stvec),
    MaskedRegMap(Scounteren, scounteren),

    //--- Supervisor Trap Handling ---
    MaskedRegMap(Sscratch, sscratch),
    MaskedRegMap(Sepc, sepc),
    MaskedRegMap(Scause, scause),
    MaskedRegMap(Stval, stval),
    MaskedRegMap(Sip, mip.asUInt, sipMask, MaskedRegMap.Unwritable, sipMask),

    //--- Supervisor Protection and Translation ---
    MaskedRegMap(Satp, satp, satpMask, MaskedRegMap.NoSideEffect, satpMask),

    //--- Supervisor Custom Read/Write Registers
    MaskedRegMap(Sbpctl, sbpctl),
    MaskedRegMap(Spfctl, spfctl),
    MaskedRegMap(Sdsid, sdsid),
    MaskedRegMap(Slvpredctl, slvpredctl),
    MaskedRegMap(Smblockctl, smblockctl),
    MaskedRegMap(Srnctl, srnctl),

    //--- Machine Information Registers ---
    MaskedRegMap(Mvendorid, mvendorid, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Marchid, marchid, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Mimpid, mimpid, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Mhartid, mhartid, 0.U, MaskedRegMap.Unwritable),

    //--- Machine Trap Setup ---
    MaskedRegMap(Mstatus, mstatus, mstatusMask, mstatusUpdateSideEffect, mstatusMask),
    // MXL and EXT are not changeable: misa is WARL, so writes are simply ignored
    MaskedRegMap(Misa, misa, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Medeleg, medeleg, "hf3ff".U),
    MaskedRegMap(Mideleg, mideleg, "h222".U),
    MaskedRegMap(Mie, mie),
    MaskedRegMap(Mtvec, mtvec),
    MaskedRegMap(Mcounteren, mcounteren),

    //--- Machine Trap Handling ---
    MaskedRegMap(Mscratch, mscratch),
    MaskedRegMap(Mepc, mepc),
    MaskedRegMap(Mcause, mcause),
    MaskedRegMap(Mtval, mtval),
    MaskedRegMap(Mip, mip.asUInt, 0.U, MaskedRegMap.Unwritable),
  )

  // PMP is unimplemented yet
  val pmpMapping = Map(
    MaskedRegMap(Pmpcfg0, pmpcfg0),
    MaskedRegMap(Pmpcfg1, pmpcfg1),
    MaskedRegMap(Pmpcfg2, pmpcfg2),
    MaskedRegMap(Pmpcfg3, pmpcfg3),
    MaskedRegMap(PmpaddrBase + 0, pmpaddr0),
    MaskedRegMap(PmpaddrBase + 1, pmpaddr1),
    MaskedRegMap(PmpaddrBase + 2, pmpaddr2),
    MaskedRegMap(PmpaddrBase + 3, pmpaddr3)
  )

  var perfCntMapping = Map(
    MaskedRegMap(Mcountinhibit, mcountinhibit),
    MaskedRegMap(Mcycle, mcycle),
    MaskedRegMap(Minstret, minstret),
  )
  val MhpmcounterStart = Mhpmcounter3
  val MhpmeventStart   = Mhpmevent3
  for (i <- 0 until nrPerfCnts) {
    perfCntMapping += MaskedRegMap(MhpmcounterStart + i, perfCnts(i))
    perfCntMapping += MaskedRegMap(MhpmeventStart + i, perfEvents(i))
  }

  val mapping = basicPrivMapping ++
                perfCntMapping ++
                pmpMapping ++
                emuPerfCntsLoMapping ++
                (if (XLEN == 32) emuPerfCntsHiMapping else Nil) ++
                (if (HasFPU) fcsrMapping else Nil)

  val addr = src2(11, 0)
  val csri = ZeroExt(src2(16, 12), XLEN)
  val rdata = Wire(UInt(XLEN.W))
  val wdata = LookupTree(func, List(
    CSROpType.wrt  -> src1,
    CSROpType.set  -> (rdata | src1),
    CSROpType.clr  -> (rdata & (~src1).asUInt()),
    CSROpType.wrti -> csri,
    CSROpType.seti -> (rdata | csri),
    CSROpType.clri -> (rdata & (~csri).asUInt())
  ))

  val addrInPerfCnt = (addr >= Mcycle.U) && (addr <= Mhpmcounter31.U)
  csrio.isPerfCnt := addrInPerfCnt

  // satp wen check: only Bare (0) and Sv39 (8) modes are accepted
  val satpLegalMode = (wdata.asTypeOf(new SatpStruct).mode===0.U) || (wdata.asTypeOf(new SatpStruct).mode===8.U)

  // general CSR wen check
  val wen = valid && func =/= CSROpType.jmp && (addr=/=Satp.U || satpLegalMode)
  val modePermitted = csrAccessPermissionCheck(addr, false.B, priviledgeMode)
  val perfcntPermitted = perfcntPermissionCheck(addr, priviledgeMode, mcounteren, scounteren)
  val permitted = Mux(addrInPerfCnt, perfcntPermitted, modePermitted)
  // Writeable check is ignored.
  // Currently, write to illegal csr addr will be ignored
  MaskedRegMap.generate(mapping, addr, rdata, wen && permitted, wdata)
  io.out.bits.data := rdata
  io.out.bits.uop := io.in.bits.uop
  io.out.bits.uop.cf := cfOut
  io.out.bits.uop.ctrl.flushPipe := flushPipe

  // Fix Mip/Sip write: software writes land in mipReg, while the mapped mip/sip are read-only views
  val fixMapping = Map(
    MaskedRegMap(Mip, mipReg.asUInt, mipFixMask),
    MaskedRegMap(Sip, mipReg.asUInt, sipMask, MaskedRegMap.NoSideEffect, sipMask)
  )
  val rdataDummy = Wire(UInt(XLEN.W))
  MaskedRegMap.generate(fixMapping, addr, rdataDummy, wen, wdata)

  when (csrio.fpu.fflags.valid) {
    fcsr := fflags_wfn(update = true)(csrio.fpu.fflags.bits)
  }
  // set fs and sd in mstatus
  when (csrw_dirty_fp_state || csrio.fpu.dirty_fs) {
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    mstatusNew.fs := "b11".U
    mstatusNew.sd := true.B
    mstatus := mstatusNew.asUInt()
  }
  csrio.fpu.frm := fcsr.asTypeOf(new FcsrStruct).frm

  // CSR inst decode
  val isEbreak = addr === privEbreak && func === CSROpType.jmp
  val isEcall  = addr === privEcall  && func === CSROpType.jmp
  val isMret   = addr === privMret   && func === CSROpType.jmp
  val isSret   = addr === privSret   && func === CSROpType.jmp
  val isUret   = addr === privUret   && func === CSROpType.jmp

  XSDebug(wen, "csr write: pc %x addr %x rdata %x wdata %x func %x\n", cfIn.pc, addr, rdata, wdata, func)
  XSDebug(wen, "pc %x mstatus %x mideleg %x medeleg %x mode %x\n", cfIn.pc, mstatus, mideleg , medeleg, priviledgeMode)

  // Illegal privileged operation list
  // SRET executed in S-mode while mstatus.TSR = 1 must raise an illegal-instruction exception
  val illegalSModeSret = valid && isSret && priviledgeMode === ModeS && mstatusStruct.tsr.asBool

  // Illegal privileged instruction check
  val isIllegalAddr = MaskedRegMap.isIllegalAddr(mapping, addr)
  val isIllegalAccess = !permitted
  val isIllegalPrivOp = illegalSModeSret

  // def MMUPermissionCheck(ptev: Bool, pteu: Bool): Bool = ptev && !(priviledgeMode === ModeU && !pteu) && !(priviledgeMode === ModeS && pteu && mstatusStruct.sum.asBool)
  // def MMUPermissionCheckLoad(ptev: Bool, pteu: Bool): Bool = ptev && !(priviledgeMode === ModeU && !pteu) && !(priviledgeMode === ModeS && pteu && mstatusStruct.sum.asBool) && (pter || (mstatusStruct.mxr && ptex))
  // imem
  // val imemPtev = true.B
  // val imemPteu = true.B
  // val imemPtex = true.B
  // val imemReq = true.B
  // val imemPermissionCheckPassed = MMUPermissionCheck(imemPtev, imemPteu)
  // val hasInstrPageFault = imemReq && !(imemPermissionCheckPassed && imemPtex)
  // assert(!hasInstrPageFault)

  // dmem
  // val dmemPtev = true.B
  // val dmemPteu = true.B
  // val dmemReq = true.B
  // val dmemPermissionCheckPassed = MMUPermissionCheck(dmemPtev, dmemPteu)
  // val dmemIsStore = true.B

  // val hasLoadPageFault  = dmemReq && !dmemIsStore && !(dmemPermissionCheckPassed)
  // val hasStorePageFault = dmemReq &&  dmemIsStore && !(dmemPermissionCheckPassed)
  // assert(!hasLoadPageFault)
  // assert(!hasStorePageFault)

  //TODO: Haven't tested if io.dmemMMU.priviledgeMode is correct yet
  tlbBundle.priv.mxr   := mstatusStruct.mxr.asBool
  tlbBundle.priv.sum   := mstatusStruct.sum.asBool
  tlbBundle.priv.imode := priviledgeMode
  tlbBundle.priv.dmode := Mux(mstatusStruct.mprv.asBool, mstatusStruct.mpp, priviledgeMode)

  // Branch control
  val retTarget = Wire(UInt(VAddrBits.W))
  val resetSatp = addr === Satp.U && wen // write to satp will cause the pipeline be flushed
  // Only genuine xRETs redirect to retTarget. ecall and ebreak trap through the
  // exception path, and an S-mode SRET with TSR set raises illegal-instruction
  // instead of returning; none of them may take the xRET redirect (whose target
  // would be the DontCare default of retTarget).
  val isXRet = valid && func === CSROpType.jmp && !isEcall && !isEbreak && !illegalSModeSret
  flushPipe := resetSatp || isXRet

  retTarget := DontCare
  // val illegalEret = TODO

  when (valid && isMret) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    mstatusNew.ie.m := mstatusOld.pie.m
    priviledgeMode := mstatusOld.mpp
    mstatusNew.pie.m := true.B
    mstatusNew.mpp := ModeU
    mstatusNew.mprv := 0.U
    mstatus := mstatusNew.asUInt
    // lr := false.B
    retTarget := mepc(VAddrBits-1, 0)
  }

  when (valid && isSret && !illegalSModeSret) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    mstatusNew.ie.s := mstatusOld.pie.s
    priviledgeMode := Cat(0.U(1.W), mstatusOld.spp)
    mstatusNew.pie.s := true.B
    mstatusNew.spp := ModeU
    mstatusNew.mprv := 0.U
    mstatus := mstatusNew.asUInt
    // lr := false.B
    retTarget := sepc(VAddrBits-1, 0)
  }

  when (valid && isUret) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    // mstatusNew.mpp.m := ModeU //TODO: add mode U
    mstatusNew.ie.u := mstatusOld.pie.u
    priviledgeMode := ModeU
    mstatusNew.pie.u := true.B
    mstatus := mstatusNew.asUInt
    retTarget := uepc(VAddrBits-1, 0)
  }

  io.in.ready := true.B
  io.out.valid := valid

  val csrExceptionVec = WireInit(cfIn.exceptionVec)
  csrExceptionVec(breakPoint) := io.in.valid && isEbreak
  csrExceptionVec(ecallM) := priviledgeMode === ModeM && io.in.valid && isEcall
  csrExceptionVec(ecallS) := priviledgeMode === ModeS && io.in.valid && isEcall
  csrExceptionVec(ecallU) := priviledgeMode === ModeU && io.in.valid && isEcall
  // Trigger an illegal instr exception when:
  // * unimplemented csr is being read/written
  // * csr access is illegal
  // * a privileged operation is illegal in the current mode (S-mode SRET with TSR = 1)
  csrExceptionVec(illegalInstr) := ((isIllegalAddr || isIllegalAccess) && wen) || isIllegalPrivOp
  cfOut.exceptionVec := csrExceptionVec

  /**
    * Exception and Intr
    */
  val ideleg =  (mideleg & mip.asUInt)
  // Interrupt enable for the target level: delegated interrupts are enabled by SIE in S-mode
  // and always below S; non-delegated ones by MIE in M-mode and always below M.
  def priviledgedEnableDetect(x: Bool): Bool = Mux(x, ((priviledgeMode === ModeS) && mstatusStruct.ie.s) || (priviledgeMode < ModeS),
    ((priviledgeMode === ModeM) && mstatusStruct.ie.m) || (priviledgeMode < ModeM))

  // send interrupt information to ROQ
  val intrVecEnable = Wire(Vec(12, Bool()))
  intrVecEnable.zip(ideleg.asBools).foreach { case (x, y) => x := priviledgedEnableDetect(y) }
  val intrVec = mie(11,0) & mip.asUInt & intrVecEnable.asUInt
  val intrBitSet = intrVec.orR()
  csrio.interrupt := intrBitSet
  mipWire.t.m := csrio.externalInterrupt.mtip
  mipWire.s.m := csrio.externalInterrupt.msip
  mipWire.e.m := csrio.externalInterrupt.meip

  // interrupts
  val intrNO = IntPriority.foldRight(0.U)((i: Int, sum: UInt) => Mux(intrVec(i), i.U, sum))
  val raiseIntr = csrio.exception.valid && csrio.exception.bits.isInterrupt
  XSDebug(raiseIntr, "interrupt: pc=0x%x, %d\n", csrio.exception.bits.uop.cf.pc, intrNO)

  // exceptions
  val raiseException = csrio.exception.valid && !csrio.exception.bits.isInterrupt
  val hasInstrPageFault = csrio.exception.bits.uop.cf.exceptionVec(instrPageFault) && raiseException
  val hasLoadPageFault = csrio.exception.bits.uop.cf.exceptionVec(loadPageFault) && raiseException
  val hasStorePageFault = csrio.exception.bits.uop.cf.exceptionVec(storePageFault) && raiseException
  val hasStoreAddrMisaligned = csrio.exception.bits.uop.cf.exceptionVec(storeAddrMisaligned) && raiseException
  val hasLoadAddrMisaligned = csrio.exception.bits.uop.cf.exceptionVec(loadAddrMisaligned) && raiseException
  val hasInstrAccessFault = csrio.exception.bits.uop.cf.exceptionVec(instrAccessFault) && raiseException
  val hasLoadAccessFault = csrio.exception.bits.uop.cf.exceptionVec(loadAccessFault) && raiseException
  val hasStoreAccessFault = csrio.exception.bits.uop.cf.exceptionVec(storeAccessFault) && raiseException

  val raiseExceptionVec = csrio.exception.bits.uop.cf.exceptionVec
  val exceptionNO = ExcPriority.foldRight(0.U)((i: Int, sum: UInt) => Mux(raiseExceptionVec(i), i.U, sum))
  // xcause: interrupt flag in the MSB, exception/interrupt code in the low bits
  val causeNO = (raiseIntr << (XLEN-1)).asUInt() | Mux(raiseIntr, intrNO, exceptionNO)

  val raiseExceptionIntr = csrio.exception.valid
  XSDebug(raiseExceptionIntr, "int/exc: pc %x int (%d):%x exc: (%d):%x\n",
    csrio.exception.bits.uop.cf.pc, intrNO, intrVec, exceptionNO, raiseExceptionVec.asUInt
  )
  XSDebug(raiseExceptionIntr,
    "pc %x mstatus %x mideleg %x medeleg %x mode %x\n",
    csrio.exception.bits.uop.cf.pc,
    mstatus,
    mideleg,
    medeleg,
    priviledgeMode
  )

  val deleg = Mux(raiseIntr, mideleg , medeleg)
  // val delegS = ((deleg & (1 << (causeNO & 0xf))) != 0) && (priviledgeMode < ModeM);
  val delegS = deleg(causeNO(3,0)) && (priviledgeMode < ModeM)
  // tval is cleared on trap entry unless this trap supplies a faulting address below
  val tvalWen = !(hasInstrPageFault || hasLoadPageFault || hasStorePageFault || hasLoadAddrMisaligned || hasStoreAddrMisaligned) || raiseIntr // TODO: need check

  // mtval / stval write logic
  // The faulting address goes to the tval register of the mode that takes the trap
  // (delegS), not the mode the hart was running in when the fault happened.
  val memExceptionAddr = SignExt(csrio.memExceptionVAddr, XLEN)
  when (hasInstrPageFault || hasLoadPageFault || hasStorePageFault || hasLoadAddrMisaligned || hasStoreAddrMisaligned) {
    val tval = Mux(
      hasInstrPageFault,
      Mux(
        csrio.exception.bits.uop.cf.crossPageIPFFix,
        SignExt(csrio.exception.bits.uop.cf.pc + 2.U, XLEN),
        SignExt(csrio.exception.bits.uop.cf.pc, XLEN)
      ),
      memExceptionAddr
    )
    when (delegS) {
      stval := tval
    }.otherwise {
      mtval := tval
    }
  }

  // ctrl block will use these later for flush
  val isXRetFlag = RegInit(false.B)
  val retTargetReg = Reg(retTarget.cloneType)
  when (io.flushIn) {
    isXRetFlag := false.B
  }.elsewhen (isXRet) {
    isXRetFlag := true.B
    retTargetReg := retTarget
  }
  csrio.isXRet := isXRetFlag
  csrio.trapTarget := Mux(isXRetFlag,
    retTargetReg,
    Mux(delegS, stvec, mtvec)(VAddrBits-1, 0)
  )

  when (raiseExceptionIntr) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))

    when (delegS) {
      scause := causeNO
      sepc := SignExt(csrio.exception.bits.uop.cf.pc, XLEN)
      mstatusNew.spp := priviledgeMode
      mstatusNew.pie.s := mstatusOld.ie.s
      mstatusNew.ie.s := false.B
      priviledgeMode := ModeS
      when (tvalWen) { stval := 0.U }
    }.otherwise {
      mcause := causeNO
      mepc := SignExt(csrio.exception.bits.uop.cf.pc, XLEN)
      mstatusNew.mpp := priviledgeMode
      mstatusNew.pie.m := mstatusOld.ie.m
      mstatusNew.ie.m := false.B
      priviledgeMode := ModeM
      when (tvalWen) { mtval := 0.U }
    }

    mstatus := mstatusNew.asUInt
  }

  // Log the pc of the trapping instruction (the value actually written to sepc).
  XSDebug(raiseExceptionIntr && delegS, "sepc is writen!!! pc:%x\n", csrio.exception.bits.uop.cf.pc)

  def readWithScala(addr: Int): UInt = mapping(addr)._1

  val difftestIntrNO = Mux(raiseIntr, causeNO, 0.U)

  if (!env.FPGAPlatform) {
    difftestIO.intrNO := RegNext(difftestIntrNO)
    difftestIO.cause := RegNext(Mux(csrio.exception.valid, causeNO, 0.U))
    difftestIO.priviledgeMode := priviledgeMode
    difftestIO.mstatus := mstatus
    difftestIO.sstatus := mstatus & sstatusRmask
    difftestIO.mepc := mepc
    difftestIO.sepc := sepc
    difftestIO.mtval:= mtval
    difftestIO.stval:= stval
    difftestIO.mtvec := mtvec
    difftestIO.stvec := stvec
    difftestIO.mcause := mcause
    difftestIO.scause := scause
    difftestIO.satp := satp
    difftestIO.mip := mipReg
    difftestIO.mie := mie
    difftestIO.mscratch := mscratch
    difftestIO.sscratch := sscratch
    difftestIO.mideleg := mideleg
    difftestIO.medeleg := medeleg
  }
}
859