// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/CSR.scala (revision 69b52b93fd0164ab9b16aec21eaf16eeb8fa614d)
1package xiangshan.backend.fu
2
3import chisel3._
4import chisel3.ExcitingUtils.{ConnectionType, Debug}
5import chisel3.util._
6import utils._
7import xiangshan._
8import xiangshan.backend._
9import xiangshan.backend.fu.util._
10
/** RISC-V synchronous exception cause numbers (the `mcause` encodings used to index
  * exception vectors), the subsets of causes each backend unit can raise, and helpers
  * that project a full exception vector onto one of those subsets.
  */
trait HasExceptionNO {
  def instrAddrMisaligned = 0
  def instrAccessFault    = 1
  def illegalInstr        = 2
  def breakPoint          = 3
  def loadAddrMisaligned  = 4
  def loadAccessFault     = 5
  def storeAddrMisaligned = 6
  def storeAccessFault    = 7
  def ecallU              = 8
  def ecallS              = 9
  def ecallM              = 11
  def instrPageFault      = 12
  def loadPageFault       = 13
  def storePageFault      = 15

  /** Exception causes listed from highest to lowest priority; a consumer folding over
    * this list (e.g. with foldRight + Mux) reports the first raised cause.
    */
  val ExcPriority = Seq(
    breakPoint, // TODO: different BP has different priority
    instrPageFault,
    instrAccessFault,
    illegalInstr,
    instrAddrMisaligned,
    ecallM, ecallS, ecallU,
    storePageFault,
    loadPageFault,
    storeAccessFault,
    loadAccessFault,
    storeAddrMisaligned,
    loadAddrMisaligned
  )
  // Causes each unit may raise.
  val frontendSet = List(
    // instrAddrMisaligned,
    instrAccessFault,
    illegalInstr,
    instrPageFault
  )
  val csrSet = List(
    illegalInstr,
    breakPoint,
    ecallU,
    ecallS,
    ecallM
  )
  val loadUnitSet = List(
    loadAddrMisaligned,
    loadAccessFault,
    loadPageFault
  )
  val storeUnitSet = List(
    storeAddrMisaligned,
    storeAccessFault,
    storePageFault
  )
  val atomicsUnitSet = (loadUnitSet ++ storeUnitSet).distinct
  val allPossibleSet = (frontendSet ++ csrSet ++ loadUnitSet ++ storeUnitSet).distinct

  // Per-cause (0..15) 0/1 indicator of whether the unit can raise that cause.
  val csrWbCount = (0 until 16).map(i => if (csrSet.contains(i)) 1 else 0)
  val loadWbCount = (0 until 16).map(i => if (loadUnitSet.contains(i)) 1 else 0)
  val storeWbCount = (0 until 16).map(i => if (storeUnitSet.contains(i)) 1 else 0)
  val atomicsWbCount = (0 until 16).map(i => if (atomicsUnitSet.contains(i)) 1 else 0)
  // Number of writeback sources that can carry each cause; store causes are weighted by 2
  // (presumably two store writeback ports -- TODO confirm against the writeback arbiter).
  val writebackCount = (0 until 16).map(i => csrWbCount(i) + atomicsWbCount(i) + loadWbCount(i) + 2 * storeWbCount(i))

  /** Projects `vec` onto the causes listed in `select`.
    *
    * @param vec          full exception vector, indexed by cause number
    * @param select       cause numbers to keep
    * @param dontCareBits when true, return a full-width ExceptionVec whose unselected bits are DontCare
    * @param falseBits    (only consulted when dontCareBits is false) return a full-width
    *                     ExceptionVec whose unselected bits are tied to false
    * @return a full-width vector as described above or, when both flags are false, a compact
    *         Vec of length `select.length` whose element i is `vec(select(i))`
    */
  def partialSelect(vec: Vec[Bool], select: Seq[Int], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] = {
    if (dontCareBits || falseBits) {
      val new_vec = Wire(ExceptionVec())
      if (dontCareBits) {
        new_vec := DontCare
      } else {
        new_vec.foreach(_ := false.B)
      }
      select.foreach(i => new_vec(i) := vec(i))
      new_vec
    } else {
      val new_vec = Wire(Vec(select.length, Bool()))
      select.zipWithIndex.foreach { case (s, i) => new_vec(i) := vec(s) }
      new_vec
    }
  }
  def selectFrontend(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, frontendSet, dontCareBits, falseBits)
  def selectCSR(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, csrSet, dontCareBits, falseBits)
  def selectLoad(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, loadUnitSet, dontCareBits, falseBits)
  def selectStore(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, storeUnitSet, dontCareBits, falseBits)
  def selectAtomics(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, atomicsUnitSet, dontCareBits, falseBits)
  def selectAll(vec: Vec[Bool], dontCareBits: Boolean = true, falseBits: Boolean = false): Vec[Bool] =
    partialSelect(vec, allPossibleSet, dontCareBits, falseBits)
}
103
/** FPU <-> CSR interface, declared from the FPU side (the CSR unit instantiates it Flipped). */
class FpuCsrIO extends XSBundle {
  // FPU -> CSR: exception flags; when valid they are OR-accumulated into fcsr.fflags
  val fflags: Valid[UInt] = Output(Valid(UInt(5.W)))
  // FPU -> CSR: not consumed by the CSR unit in this file (TODO confirm consumer)
  val isIllegal: Bool = Output(Bool())
  // FPU -> CSR: request to mark mstatus.FS (and SD) dirty
  val dirty_fs: Bool = Output(Bool())
  // CSR -> FPU: current dynamic rounding mode (fcsr.frm)
  val frm: UInt = Input(UInt(3.W))
}
110
111
/** Performance-counter inputs into the CSR unit. */
class PerfCounterIO extends XSBundle {
  // retired-instruction count; the CSR unit adds it (one cycle delayed) to minstret
  val retiredInstr: UInt = Input(UInt(3.W))
  // not read by the CSR unit in this file -- TODO confirm purpose
  val value: UInt = Input(UInt(XLEN.W))
}
116
/** The CSR functional unit.
  *
  * Holds the machine- and supervisor-level CSRs (plus fcsr when HasFPU), executes CSR
  * read/modify/write ops and the privileged `CSROpType.jmp` ops (ecall, ebreak, mret,
  * sret, uret), raises the exceptions those ops cause, and performs trap entry
  * (cause/epc/tval/mstatus/privilege update) for exceptions and interrupts reported back
  * through `csrio.exception`.
  *
  * NOTE: Chisel last-connect semantics apply, so the relative order of the blocks that
  * write `mstatus`, `mtval`/`stval` and `priviledgeMode` matters: later blocks win.
  */
class CSR extends FunctionUnit with HasCSRConst
{
  val csrio = IO(new Bundle {
    val hartId = Input(UInt(64.W))
    // performance counter inputs (retired-instruction count feeds minstret)
    val perf = new PerfCounterIO
    val isPerfCnt = Output(Bool())
    // to FPU
    val fpu = Flipped(new FpuCsrIO)
    // from rob
    val exception = Flipped(ValidIO(new ExceptionInfo))
    // to ROB
    val isXRet = Output(Bool())
    val trapTarget = Output(UInt(VAddrBits.W))
    val interrupt = Output(Bool())
    // from LSQ
    val memExceptionVAddr = Input(UInt(VAddrBits.W))
    // from outside cpu,externalInterrupt
    val externalInterrupt = new ExternalInterruptIO
    // TLB
    val tlb = Output(new TlbCsrBundle)
  })
  // Architectural state exported for difftest; only driven when !env.FPGAPlatform.
  val difftestIO = IO(new Bundle() {
    val intrNO = Output(UInt(64.W))
    val cause = Output(UInt(64.W))
    val priviledgeMode = Output(UInt(2.W))
    val mstatus = Output(UInt(64.W))
    val sstatus = Output(UInt(64.W))
    val mepc = Output(UInt(64.W))
    val sepc = Output(UInt(64.W))
    val mtval = Output(UInt(64.W))
    val stval = Output(UInt(64.W))
    val mtvec = Output(UInt(64.W))
    val stvec = Output(UInt(64.W))
    val mcause = Output(UInt(64.W))
    val scause = Output(UInt(64.W))
    val satp = Output(UInt(64.W))
    val mip = Output(UInt(64.W))
    val mie = Output(UInt(64.W))
    val mscratch = Output(UInt(64.W))
    val sscratch = Output(UInt(64.W))
    val mideleg = Output(UInt(64.W))
    val medeleg = Output(UInt(64.W))
  })
  difftestIO <> DontCare

  val cfIn = io.in.bits.uop.cf
  val cfOut = Wire(new CtrlFlow)
  cfOut := cfIn
  val flushPipe = Wire(Bool())

  val (valid, src1, src2, func) = (
    io.in.valid,
    io.in.bits.src(0),
    io.in.bits.uop.ctrl.imm,
    io.in.bits.uop.ctrl.fuOpType
  )

  // CSR define

  class Priv extends Bundle {
    val m = Output(Bool())
    val h = Output(Bool())
    val s = Output(Bool())
    val u = Output(Bool())
  }

  val csrNotImplemented = RegInit(UInt(XLEN.W), 0.U)

  class MstatusStruct extends Bundle {
    val sd = Output(UInt(1.W))

    val pad1 = if (XLEN == 64) Output(UInt(27.W)) else null
    val sxl  = if (XLEN == 64) Output(UInt(2.W))  else null
    val uxl  = if (XLEN == 64) Output(UInt(2.W))  else null
    val pad0 = if (XLEN == 64) Output(UInt(9.W))  else Output(UInt(8.W))

    val tsr = Output(UInt(1.W))
    val tw = Output(UInt(1.W))
    val tvm = Output(UInt(1.W))
    val mxr = Output(UInt(1.W))
    val sum = Output(UInt(1.W))
    val mprv = Output(UInt(1.W))
    val xs = Output(UInt(2.W))
    val fs = Output(UInt(2.W))
    val mpp = Output(UInt(2.W))
    val hpp = Output(UInt(2.W))
    val spp = Output(UInt(1.W))
    val pie = new Priv
    val ie = new Priv
    assert(this.getWidth == XLEN)
  }

  class SatpStruct extends Bundle {
    val mode = UInt(4.W)
    val asid = UInt(16.W)
    val ppn  = UInt(44.W)
  }

  class Interrupt extends Bundle {
    val e = new Priv
    val t = new Priv
    val s = new Priv
  }

  // Machine-Level CSRs

  val mtvec = RegInit(UInt(XLEN.W), 0.U)
  val mcounteren = RegInit(UInt(XLEN.W), 0.U)
  val mcause = RegInit(UInt(XLEN.W), 0.U)
  val mtval = RegInit(UInt(XLEN.W), 0.U)
  // reset to 0 like the other trap CSRs (was an uninitialised Reg)
  val mepc = RegInit(UInt(XLEN.W), 0.U)

  val mie = RegInit(0.U(XLEN.W))
  // mip = hardware-driven bits (mipWire) | software-writable bits (mipReg)
  val mipWire = WireInit(0.U.asTypeOf(new Interrupt))
  val mipReg  = RegInit(0.U.asTypeOf(new Interrupt).asUInt)
  val mipFixMask = GenMask(9) | GenMask(5) | GenMask(1)
  val mip = (mipWire.asUInt | mipReg).asTypeOf(new Interrupt)

  def getMisaMxl(mxl: Int): UInt = {mxl.U << (XLEN-2)}.asUInt()
  def getMisaExt(ext: Char): UInt = {1.U << (ext.toInt - 'a'.toInt)}.asUInt()
  var extList = List('a', 's', 'i', 'u')
  if (HasMExtension) { extList = extList :+ 'm' }
  if (HasCExtension) { extList = extList :+ 'c' }
  if (HasFPU) { extList = extList ++ List('f', 'd') }
  val misaInitVal = getMisaMxl(2) | extList.foldLeft(0.U)((sum, i) => sum | getMisaExt(i)) //"h8000000000141105".U
  val misa = RegInit(UInt(XLEN.W), misaInitVal)

  // MXL = 2          | 0 | EXT = b 00 0000 0100 0001 0001 0000 0101
  // (XLEN-1, XLEN-2) |   |(25, 0)  ZY XWVU TSRQ PONM LKJI HGFE DCBA

  val mvendorid = RegInit(UInt(XLEN.W), 0.U) // this is a non-commercial implementation
  val marchid = RegInit(UInt(XLEN.W), 0.U) // return 0 to indicate the field is not implemented
  val mimpid = RegInit(UInt(XLEN.W), 0.U) // provides a unique encoding of the version of the processor implementation
  val mhartid = RegInit(UInt(XLEN.W), csrio.hartId) // the hardware thread running the code
  val mstatus = RegInit(UInt(XLEN.W), 0.U)

  // mstatus Value Table
  // | sd   |
  // | pad1 |
  // | sxl  | hardlinked to 10, use 00 to pass xv6 test
  // | uxl  | hardlinked to 00
  // | pad0 |
  // | tsr  |
  // | tw   |
  // | tvm  |
  // | mxr  |
  // | sum  |
  // | mprv |
  // | xs   | 00 |
  // | fs   | 00 |
  // | mpp  | 00 |
  // | hpp  | 00 |
  // | spp  | 0 |
  // | pie  | 0000 | pie.h is used as UBE
  // | ie   | 0000 | uie hardlinked to 0, as N ext is not implemented

  val mstatusStruct = mstatus.asTypeOf(new MstatusStruct)
  // Recompute SD (top bit) from XS/FS whenever mstatus/sstatus is written.
  def mstatusUpdateSideEffect(mstatus: UInt): UInt = {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = Cat(mstatusOld.xs === "b11".U || mstatusOld.fs === "b11".U, mstatus(XLEN-2, 0))
    mstatusNew
  }

  val mstatusMask = (~ZeroExt((
    GenMask(XLEN-2, 38) | GenMask(31, 23) | GenMask(10, 9) | GenMask(2) |
    GenMask(37) | // MBE
    GenMask(36) | // SBE
    GenMask(6)    // UBE
  ), 64)).asUInt()

  val medeleg = RegInit(UInt(XLEN.W), 0.U)
  val mideleg = RegInit(UInt(XLEN.W), 0.U)
  val mscratch = RegInit(UInt(XLEN.W), 0.U)

  val pmpcfg0 = RegInit(UInt(XLEN.W), 0.U)
  val pmpcfg1 = RegInit(UInt(XLEN.W), 0.U)
  val pmpcfg2 = RegInit(UInt(XLEN.W), 0.U)
  val pmpcfg3 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr0 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr1 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr2 = RegInit(UInt(XLEN.W), 0.U)
  val pmpaddr3 = RegInit(UInt(XLEN.W), 0.U)

  // Superviser-Level CSRs

  // val sstatus = RegInit(UInt(XLEN.W), "h00000000".U)
  val sstatusWmask = "hc6122".U
  // Sstatus Write Mask
  // -------------------------------------------------------
  // bits: 19:16 15:12 11:8  7:4   3:0
  //       1100  0110  0001  0010  0010
  //       c     6     1     2     2
  // => MXR(19) SUM(18) FS(14:13) SPP(8) SPIE(5) SIE(1)
  // -------------------------------------------------------
  val sstatusRmask = sstatusWmask | "h8000000300018000".U
  // Sstatus Read Mask = write mask | SD(63) | UXL(33:32) | XS(16:15)
  val stvec = RegInit(UInt(XLEN.W), 0.U)
  // val sie = RegInit(0.U(XLEN.W))
  val sieMask = "h222".U & mideleg
  val sipMask  = "h222".U & mideleg
  val satp = if(EnbaleTlbDebug) RegInit(UInt(XLEN.W), "h8000000000087fbe".U) else RegInit(0.U(XLEN.W))
  // val satp = RegInit(UInt(XLEN.W), "h8000000000087fbe".U) // only use for tlb naive debug
  val satpMask = "h80000fffffffffff".U // disable asid, mode can only be 8 / 0
  val sepc = RegInit(UInt(XLEN.W), 0.U)
  val scause = RegInit(UInt(XLEN.W), 0.U)
  // reset to 0 like the other trap CSRs (was an uninitialised Reg)
  val stval = RegInit(UInt(XLEN.W), 0.U)
  val sscratch = RegInit(UInt(XLEN.W), 0.U)
  val scounteren = RegInit(UInt(XLEN.W), 0.U)

  val tlbBundle = Wire(new TlbCsrBundle)
  tlbBundle.satp := satp.asTypeOf(new SatpStruct)
  csrio.tlb := tlbBundle

  // User-Level CSRs
  val uepc = RegInit(UInt(XLEN.W), 0.U)

  // fcsr
  class FcsrStruct extends Bundle {
    val reserved = UInt((XLEN-3-5).W)
    val frm = UInt(3.W)
    val fflags = UInt(5.W)
    assert(this.getWidth == XLEN)
  }
  val fcsr = RegInit(0.U(XLEN.W))
  // set mstatus->sd and mstatus->fs when true
  val csrw_dirty_fp_state = WireInit(false.B)

  def frm_wfn(wdata: UInt): UInt = {
    val fcsrOld = WireInit(fcsr.asTypeOf(new FcsrStruct))
    csrw_dirty_fp_state := true.B
    fcsrOld.frm := wdata(2,0)
    fcsrOld.asUInt()
  }
  def frm_rfn(rdata: UInt): UInt = rdata(7,5)

  // update = true ORs the new flags into fflags (FPU accrual); false overwrites them (CSR write)
  def fflags_wfn(update: Boolean)(wdata: UInt): UInt = {
    val fcsrOld = fcsr.asTypeOf(new FcsrStruct)
    val fcsrNew = WireInit(fcsrOld)
    csrw_dirty_fp_state := true.B
    if (update) {
      fcsrNew.fflags := wdata(4,0) | fcsrOld.fflags
    } else {
      fcsrNew.fflags := wdata(4,0)
    }
    fcsrNew.asUInt()
  }
  def fflags_rfn(rdata:UInt): UInt = rdata(4,0)

  def fcsr_wfn(wdata: UInt): UInt = {
    val fcsrOld = WireInit(fcsr.asTypeOf(new FcsrStruct))
    csrw_dirty_fp_state := true.B
    Cat(fcsrOld.reserved, wdata.asTypeOf(fcsrOld).frm, wdata.asTypeOf(fcsrOld).fflags)
  }

  val fcsrMapping = Map(
    MaskedRegMap(Fflags, fcsr, wfn = fflags_wfn(update = false), rfn = fflags_rfn),
    MaskedRegMap(Frm, fcsr, wfn = frm_wfn, rfn = frm_rfn),
    MaskedRegMap(Fcsr, fcsr, wfn = fcsr_wfn)
  )

  // Atom LR/SC Control Bits
  //  val setLr = WireInit(Bool(), false.B)
  //  val setLrVal = WireInit(Bool(), false.B)
  //  val setLrAddr = WireInit(UInt(AddrBits.W), DontCare) //TODO : need check
  //  val lr = RegInit(Bool(), false.B)
  //  val lrAddr = RegInit(UInt(AddrBits.W), 0.U)
  //
  //  when (setLr) {
  //    lr := setLrVal
  //    lrAddr := setLrAddr
  //  }

  // Hart Priviledge Mode
  val priviledgeMode = RegInit(UInt(2.W), ModeM)

  // Emu perfcnt
  val hasEmuPerfCnt = !env.FPGAPlatform
  val nrEmuPerfCnts = if (hasEmuPerfCnt) 0x80 else 0x3

  val emuPerfCnts    = List.fill(nrEmuPerfCnts)(RegInit(0.U(XLEN.W)))
  val emuPerfCntCond = List.fill(nrEmuPerfCnts)(WireInit(false.B))
  (emuPerfCnts zip emuPerfCntCond).foreach { case (c, e) => when (e) { c := c + 1.U } }

  val emuPerfCntsLoMapping = (0 until nrEmuPerfCnts).map(i => MaskedRegMap(0x1000 + i, emuPerfCnts(i)))
  val emuPerfCntsHiMapping = (0 until nrEmuPerfCnts).map(i => MaskedRegMap(0x1080 + i, emuPerfCnts(i)(63, 32)))
  println(s"CSR: hasEmuPerfCnt:${hasEmuPerfCnt}")

  // Perf Counter
  val nrPerfCnts = 29  // 3...31
  val perfCnts   = List.fill(nrPerfCnts)(RegInit(0.U(XLEN.W)))
  val perfEvents = List.fill(nrPerfCnts)(RegInit(0.U(XLEN.W)))
  // NOTE: mcountinhibit is readable/writable but does not currently inhibit mcycle/minstret.
  val mcountinhibit = RegInit(0.U(XLEN.W))
  val mcycle = RegInit(0.U(XLEN.W))
  mcycle := mcycle + 1.U
  val minstret = RegInit(0.U(XLEN.W))
  minstret := minstret + RegNext(csrio.perf.retiredInstr)

  // CSR reg map
  val basicPrivMapping = Map(

    //--- User Trap Setup ---
    // MaskedRegMap(Ustatus, ustatus),
    // MaskedRegMap(Uie, uie, 0.U, MaskedRegMap.Unwritable),
    // MaskedRegMap(Utvec, utvec),

    //--- User Trap Handling ---
    // MaskedRegMap(Uscratch, uscratch),
    // MaskedRegMap(Uepc, uepc),
    // MaskedRegMap(Ucause, ucause),
    // MaskedRegMap(Utval, utval),
    // MaskedRegMap(Uip, uip),

    //--- User Counter/Timers ---
    // MaskedRegMap(Cycle, cycle),
    // MaskedRegMap(Time, time),
    // MaskedRegMap(Instret, instret),

    //--- Supervisor Trap Setup ---
    MaskedRegMap(Sstatus, mstatus, sstatusWmask, mstatusUpdateSideEffect, sstatusRmask),
    // MaskedRegMap(Sedeleg, Sedeleg),
    // MaskedRegMap(Sideleg, Sideleg),
    MaskedRegMap(Sie, mie, sieMask, MaskedRegMap.NoSideEffect, sieMask),
    MaskedRegMap(Stvec, stvec),
    MaskedRegMap(Scounteren, scounteren),

    //--- Supervisor Trap Handling ---
    MaskedRegMap(Sscratch, sscratch),
    MaskedRegMap(Sepc, sepc),
    MaskedRegMap(Scause, scause),
    MaskedRegMap(Stval, stval),
    MaskedRegMap(Sip, mip.asUInt, sipMask, MaskedRegMap.Unwritable, sipMask),

    //--- Supervisor Protection and Translation ---
    MaskedRegMap(Satp, satp, satpMask, MaskedRegMap.NoSideEffect, satpMask),

    //--- Machine Information Registers ---
    MaskedRegMap(Mvendorid, mvendorid, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Marchid, marchid, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Mimpid, mimpid, 0.U, MaskedRegMap.Unwritable),
    MaskedRegMap(Mhartid, mhartid, 0.U, MaskedRegMap.Unwritable),

    //--- Machine Trap Setup ---
    MaskedRegMap(Mstatus, mstatus, mstatusMask, mstatusUpdateSideEffect, mstatusMask),
    MaskedRegMap(Misa, misa), // now MXL, EXT is not changeable
    MaskedRegMap(Medeleg, medeleg, "hf3ff".U),
    MaskedRegMap(Mideleg, mideleg, "h222".U),
    MaskedRegMap(Mie, mie),
    MaskedRegMap(Mtvec, mtvec),
    MaskedRegMap(Mcounteren, mcounteren),

    //--- Machine Trap Handling ---
    MaskedRegMap(Mscratch, mscratch),
    MaskedRegMap(Mepc, mepc),
    MaskedRegMap(Mcause, mcause),
    MaskedRegMap(Mtval, mtval),
    MaskedRegMap(Mip, mip.asUInt, 0.U, MaskedRegMap.Unwritable),
  )

  // PMP is unimplemented yet
  val pmpMapping = Map(
    MaskedRegMap(Pmpcfg0, pmpcfg0),
    MaskedRegMap(Pmpcfg1, pmpcfg1),
    MaskedRegMap(Pmpcfg2, pmpcfg2),
    MaskedRegMap(Pmpcfg3, pmpcfg3),
    MaskedRegMap(PmpaddrBase + 0, pmpaddr0),
    MaskedRegMap(PmpaddrBase + 1, pmpaddr1),
    MaskedRegMap(PmpaddrBase + 2, pmpaddr2),
    MaskedRegMap(PmpaddrBase + 3, pmpaddr3)
  )

  var perfCntMapping = Map(
    MaskedRegMap(Mcountinhibit, mcountinhibit),
    MaskedRegMap(Mcycle, mcycle),
    MaskedRegMap(Minstret, minstret),
  )
  val MhpmcounterStart = Mhpmcounter3
  val MhpmeventStart   = Mhpmevent3
  for (i <- 0 until nrPerfCnts) {
    perfCntMapping += MaskedRegMap(MhpmcounterStart + i, perfCnts(i))
    perfCntMapping += MaskedRegMap(MhpmeventStart + i, perfEvents(i))
  }

  val mapping = basicPrivMapping ++
                perfCntMapping ++
                pmpMapping ++
                emuPerfCntsLoMapping ++
                (if (XLEN == 32) emuPerfCntsHiMapping else Nil) ++
                (if (HasFPU) fcsrMapping else Nil)

  val addr = src2(11, 0)
  val csri = ZeroExt(src2(16, 12), XLEN)
  val rdata = Wire(UInt(XLEN.W))
  val wdata = LookupTree(func, List(
    CSROpType.wrt  -> src1,
    CSROpType.set  -> (rdata | src1),
    CSROpType.clr  -> (rdata & (~src1).asUInt()),
    CSROpType.wrti -> csri,
    CSROpType.seti -> (rdata | csri),
    CSROpType.clri -> (rdata & (~csri).asUInt())
  ))

  val addrInPerfCnt = (addr >= Mcycle.U) && (addr <= Mhpmcounter31.U)
  csrio.isPerfCnt := addrInPerfCnt

  // satp wen check
  val satpLegalMode = (wdata.asTypeOf(new SatpStruct).mode===0.U) || (wdata.asTypeOf(new SatpStruct).mode===8.U)

  // general CSR wen check
  val wen = valid && func =/= CSROpType.jmp && (addr=/=Satp.U || satpLegalMode)
  val modePermitted = csrAccessPermissionCheck(addr, false.B, priviledgeMode)
  val perfcntPermitted = perfcntPermissionCheck(addr, priviledgeMode, mcounteren, scounteren)
  // addrInPerfCnt covers the machine-level counters (mcycle..mhpmcounter31), which are
  // accessible only from M-mode. mcounteren/scounteren must never widen that, so the
  // counter-enable check may only further restrict the address-based privilege check.
  val permitted = modePermitted && (!addrInPerfCnt || perfcntPermitted)
  // Writeable check is ingored.
  // Currently, write to illegal csr addr will be ignored
  MaskedRegMap.generate(mapping, addr, rdata, wen && permitted, wdata)
  io.out.bits.data := rdata
  io.out.bits.uop := io.in.bits.uop
  io.out.bits.uop.cf := cfOut
  io.out.bits.uop.ctrl.flushPipe := flushPipe

  // Fix Mip/Sip write: `mip` in the main mapping is a read-only view (mipWire | mipReg);
  // the software-writable bits live in mipReg. Gate with `permitted` so an access that
  // raises an illegal-instruction exception cannot still modify mipReg.
  val fixMapping = Map(
    MaskedRegMap(Mip, mipReg.asUInt, mipFixMask),
    MaskedRegMap(Sip, mipReg.asUInt, sipMask, MaskedRegMap.NoSideEffect, sipMask)
  )
  val rdataDummy = Wire(UInt(XLEN.W))
  MaskedRegMap.generate(fixMapping, addr, rdataDummy, wen && permitted, wdata)

  when (csrio.fpu.fflags.valid) {
    fcsr := fflags_wfn(update = true)(csrio.fpu.fflags.bits)
  }
  // set fs and sd in mstatus
  when (csrw_dirty_fp_state || csrio.fpu.dirty_fs) {
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    mstatusNew.fs := "b11".U
    mstatusNew.sd := true.B
    mstatus := mstatusNew.asUInt()
  }
  csrio.fpu.frm := fcsr.asTypeOf(new FcsrStruct).frm

  // CSR inst decode
  val isEbreak = addr === privEbreak && func === CSROpType.jmp
  val isEcall  = addr === privEcall  && func === CSROpType.jmp
  val isMret   = addr === privMret   && func === CSROpType.jmp
  val isSret   = addr === privSret   && func === CSROpType.jmp
  val isUret   = addr === privUret   && func === CSROpType.jmp

  XSDebug(wen, "csr write: pc %x addr %x rdata %x wdata %x func %x\n", cfIn.pc, addr, rdata, wdata, func)
  XSDebug(wen, "pc %x mstatus %x mideleg %x medeleg %x mode %x\n", cfIn.pc, mstatus, mideleg , medeleg, priviledgeMode)

  // Illegal priviledged operation list
  val illegalMret = valid && isMret && priviledgeMode < ModeM
  val illegalSret = valid && isSret && priviledgeMode < ModeS
  val illegalSModeSret = valid && isSret && priviledgeMode === ModeS && mstatusStruct.tsr.asBool

  // Illegal priviledged instruction check
  val isIllegalAddr = MaskedRegMap.isIllegalAddr(mapping, addr)
  val isIllegalAccess = !permitted
  val isIllegalPrivOp = illegalMret || illegalSret || illegalSModeSret

  // def MMUPermissionCheck(ptev: Bool, pteu: Bool): Bool = ptev && !(priviledgeMode === ModeU && !pteu) && !(priviledgeMode === ModeS && pteu && mstatusStruct.sum.asBool)
  // def MMUPermissionCheckLoad(ptev: Bool, pteu: Bool): Bool = ptev && !(priviledgeMode === ModeU && !pteu) && !(priviledgeMode === ModeS && pteu && mstatusStruct.sum.asBool) && (pter || (mstatusStruct.mxr && ptex))
  // imem
  // val imemPtev = true.B
  // val imemPteu = true.B
  // val imemPtex = true.B
  // val imemReq = true.B
  // val imemPermissionCheckPassed = MMUPermissionCheck(imemPtev, imemPteu)
  // val hasInstrPageFault = imemReq && !(imemPermissionCheckPassed && imemPtex)
  // assert(!hasInstrPageFault)

  // dmem
  // val dmemPtev = true.B
  // val dmemPteu = true.B
  // val dmemReq = true.B
  // val dmemPermissionCheckPassed = MMUPermissionCheck(dmemPtev, dmemPteu)
  // val dmemIsStore = true.B

  // val hasLoadPageFault  = dmemReq && !dmemIsStore && !(dmemPermissionCheckPassed)
  // val hasStorePageFault = dmemReq &&  dmemIsStore && !(dmemPermissionCheckPassed)
  // assert(!hasLoadPageFault)
  // assert(!hasStorePageFault)

  //TODO: Havn't test if io.dmemMMU.priviledgeMode is correct yet
  tlbBundle.priv.mxr   := mstatusStruct.mxr.asBool
  tlbBundle.priv.sum   := mstatusStruct.sum.asBool
  tlbBundle.priv.imode := priviledgeMode
  tlbBundle.priv.dmode := Mux(mstatusStruct.mprv.asBool, mstatusStruct.mpp, priviledgeMode)

  // Branch control
  val retTarget = Wire(UInt(VAddrBits.W))
  val resetSatp = addr === Satp.U && wen // write to satp will cause the pipeline be flushed
  flushPipe := resetSatp || (valid && func === CSROpType.jmp && !isEcall)

  retTarget := DontCare

  // MRET is only legal in M-mode; an illegal MRET raises illegalInstr and changes no state.
  when (valid && isMret && !illegalMret) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    mstatusNew.ie.m := mstatusOld.pie.m
    priviledgeMode := mstatusOld.mpp
    mstatusNew.pie.m := true.B
    mstatusNew.mpp := ModeU
    // xRET clears MPRV only when it lowers the privilege below M
    when (mstatusOld.mpp =/= ModeM) { mstatusNew.mprv := 0.U }
    mstatus := mstatusNew.asUInt
    // lr := false.B
    retTarget := mepc(VAddrBits-1, 0)
  }

  // SRET is illegal below S-mode, and in S-mode when mstatus.TSR is set.
  when (valid && isSret && !illegalSret && !illegalSModeSret) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    mstatusNew.ie.s := mstatusOld.pie.s
    priviledgeMode := Cat(0.U(1.W), mstatusOld.spp)
    mstatusNew.pie.s := true.B
    mstatusNew.spp := ModeU
    // SRET always returns to S or U, so MPRV is always cleared
    mstatusNew.mprv := 0.U
    mstatus := mstatusNew.asUInt
    // lr := false.B
    retTarget := sepc(VAddrBits-1, 0)
  }

  // NOTE(review): misa does not advertise the N extension and uepc is not CSR-mapped,
  // yet URET is still executed here (dropping to U-mode); consider raising illegalInstr.
  when (valid && isUret) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))
    // mstatusNew.mpp.m := ModeU //TODO: add mode U
    mstatusNew.ie.u := mstatusOld.pie.u
    priviledgeMode := ModeU
    mstatusNew.pie.u := true.B
    mstatus := mstatusNew.asUInt
    retTarget := uepc(VAddrBits-1, 0)
  }

  io.in.ready := true.B
  io.out.valid := valid

  val csrExceptionVec = WireInit(cfIn.exceptionVec)
  csrExceptionVec(breakPoint) := io.in.valid && isEbreak
  csrExceptionVec(ecallM) := priviledgeMode === ModeM && io.in.valid && isEcall
  csrExceptionVec(ecallS) := priviledgeMode === ModeS && io.in.valid && isEcall
  csrExceptionVec(ecallU) := priviledgeMode === ModeU && io.in.valid && isEcall
  // Trigger an illegal instr exception when:
  // * unimplemented csr is being read/written
  // * csr access is illegal
  // * mret/sret is executed without sufficient privilege (or sret with TSR set in S-mode)
  csrExceptionVec(illegalInstr) := (isIllegalAddr || isIllegalAccess) && wen || isIllegalPrivOp
  cfOut.exceptionVec := csrExceptionVec

  /**
    * Exception and Intr
    */
  val ideleg =  (mideleg & mip.asUInt)
  def priviledgedEnableDetect(x: Bool): Bool = Mux(x, ((priviledgeMode === ModeS) && mstatusStruct.ie.s) || (priviledgeMode < ModeS),
    ((priviledgeMode === ModeM) && mstatusStruct.ie.m) || (priviledgeMode < ModeM))

  // send interrupt information to ROQ
  val intrVecEnable = Wire(Vec(12, Bool()))
  intrVecEnable.zip(ideleg.asBools).foreach { case (x, y) => x := priviledgedEnableDetect(y) }
  val intrVec = mie(11,0) & mip.asUInt & intrVecEnable.asUInt
  val intrBitSet = intrVec.orR()
  csrio.interrupt := intrBitSet
  mipWire.t.m := csrio.externalInterrupt.mtip
  mipWire.s.m := csrio.externalInterrupt.msip
  mipWire.e.m := csrio.externalInterrupt.meip

  // interrupts
  val intrNO = IntPriority.foldRight(0.U)((i: Int, sum: UInt) => Mux(intrVec(i), i.U, sum))
  val raiseIntr = csrio.exception.valid && csrio.exception.bits.isInterrupt
  XSDebug(raiseIntr, "interrupt: pc=0x%x, %d\n", csrio.exception.bits.uop.cf.pc, intrNO)

  // exceptions
  val raiseException = csrio.exception.valid && !csrio.exception.bits.isInterrupt
  val hasInstrPageFault = csrio.exception.bits.uop.cf.exceptionVec(instrPageFault) && raiseException
  val hasLoadPageFault = csrio.exception.bits.uop.cf.exceptionVec(loadPageFault) && raiseException
  val hasStorePageFault = csrio.exception.bits.uop.cf.exceptionVec(storePageFault) && raiseException
  val hasStoreAddrMisaligned = csrio.exception.bits.uop.cf.exceptionVec(storeAddrMisaligned) && raiseException
  val hasLoadAddrMisaligned = csrio.exception.bits.uop.cf.exceptionVec(loadAddrMisaligned) && raiseException
  val hasInstrAccessFault = csrio.exception.bits.uop.cf.exceptionVec(instrAccessFault) && raiseException
  val hasLoadAccessFault = csrio.exception.bits.uop.cf.exceptionVec(loadAccessFault) && raiseException
  val hasStoreAccessFault = csrio.exception.bits.uop.cf.exceptionVec(storeAccessFault) && raiseException

  val raiseExceptionVec = csrio.exception.bits.uop.cf.exceptionVec
  val exceptionNO = ExcPriority.foldRight(0.U)((i: Int, sum: UInt) => Mux(raiseExceptionVec(i), i.U, sum))
  val causeNO = (raiseIntr << (XLEN-1)).asUInt() | Mux(raiseIntr, intrNO, exceptionNO)

  val raiseExceptionIntr = csrio.exception.valid
  XSDebug(raiseExceptionIntr, "int/exc: pc %x int (%d):%x exc: (%d):%x\n",
    csrio.exception.bits.uop.cf.pc, intrNO, intrVec, exceptionNO, raiseExceptionVec.asUInt
  )
  XSDebug(raiseExceptionIntr,
    "pc %x mstatus %x mideleg %x medeleg %x mode %x\n",
    csrio.exception.bits.uop.cf.pc,
    mstatus,
    mideleg,
    medeleg,
    priviledgeMode
  )

  // Trap delegation: the trap is taken in S-mode iff its cause is delegated and we are not in M-mode.
  // Computed before the tval logic below, which must target the trap's destination mode.
  val deleg = Mux(raiseIntr, mideleg , medeleg)
  // val delegS = ((deleg & (1 << (causeNO & 0xf))) != 0) && (priviledgeMode < ModeM);
  val delegS = deleg(causeNO(3,0)) && (priviledgeMode < ModeM)

  // mtval/stval write logic: the faulting address goes to the tval CSR of the mode the
  // trap is taken in (stval if delegated to S, mtval otherwise).
  val memExceptionAddr = SignExt(csrio.memExceptionVAddr, XLEN)
  val hasMemAddrTval = hasLoadPageFault || hasStorePageFault || hasLoadAddrMisaligned || hasStoreAddrMisaligned
  when (hasInstrPageFault || hasMemAddrTval) {
    val tval = Mux(
      hasInstrPageFault,
      Mux(
        csrio.exception.bits.uop.cf.crossPageIPFFix,
        SignExt(csrio.exception.bits.uop.cf.pc + 2.U, XLEN),
        SignExt(csrio.exception.bits.uop.cf.pc, XLEN)
      ),
      memExceptionAddr
    )
    when (delegS) {
      stval := tval
    }.otherwise {
      mtval := tval
    }
  }

  // For traps without an address-carrying tval, clear the destination tval to 0.
  val tvalWen = !(hasInstrPageFault || hasMemAddrTval) || raiseIntr // TODO: need check
  // A legal xRET redirects to retTarget; ecall/ebreak and illegal xRETs trap to xtvec instead.
  val isXRet = io.in.valid && func === CSROpType.jmp && !isEcall && !isEbreak && !isIllegalPrivOp

  // ctrl block will use theses later for flush
  val isXRetFlag = RegInit(false.B)
  val retTargetReg = Reg(retTarget.cloneType)
  when (io.flushIn) {
    isXRetFlag := false.B
  }.elsewhen (isXRet) {
    isXRetFlag := true.B
    retTargetReg := retTarget
  }
  csrio.isXRet := isXRetFlag
  csrio.trapTarget := Mux(isXRetFlag,
    retTargetReg,
    Mux(delegS, stvec, mtvec)(VAddrBits-1, 0)
  )

  when (raiseExceptionIntr) {
    val mstatusOld = WireInit(mstatus.asTypeOf(new MstatusStruct))
    val mstatusNew = WireInit(mstatus.asTypeOf(new MstatusStruct))

    when (delegS) {
      scause := causeNO
      sepc := SignExt(csrio.exception.bits.uop.cf.pc, XLEN)
      mstatusNew.spp := priviledgeMode
      mstatusNew.pie.s := mstatusOld.ie.s
      mstatusNew.ie.s := false.B
      priviledgeMode := ModeS
      when (tvalWen) { stval := 0.U }
    }.otherwise {
      mcause := causeNO
      mepc := SignExt(csrio.exception.bits.uop.cf.pc, XLEN)
      mstatusNew.mpp := priviledgeMode
      mstatusNew.pie.m := mstatusOld.ie.m
      mstatusNew.ie.m := false.B
      priviledgeMode := ModeM
      when (tvalWen) { mtval := 0.U }
    }

    mstatus := mstatusNew.asUInt
  }

  XSDebug(raiseExceptionIntr && delegS, "sepc is writen!!! pc:%x\n", cfIn.pc)

  def readWithScala(addr: Int): UInt = mapping(addr)._1

  val difftestIntrNO = Mux(raiseIntr, causeNO, 0.U)

  if (!env.FPGAPlatform) {
    difftestIO.intrNO := RegNext(difftestIntrNO)
    difftestIO.cause := RegNext(Mux(csrio.exception.valid, causeNO, 0.U))
    difftestIO.priviledgeMode := priviledgeMode
    difftestIO.mstatus := mstatus
    difftestIO.sstatus := mstatus & sstatusRmask
    difftestIO.mepc := mepc
    difftestIO.sepc := sepc
    difftestIO.mtval:= mtval
    difftestIO.stval:= stval
    difftestIO.mtvec := mtvec
    difftestIO.stvec := stvec
    difftestIO.mcause := mcause
    difftestIO.scause := scause
    difftestIO.satp := satp
    difftestIO.mip := mipReg
    difftestIO.mie := mie
    difftestIO.mscratch := mscratch
    difftestIO.sscratch := sscratch
    difftestIO.mideleg := mideleg
    difftestIO.medeleg := medeleg
  }
}
811