xref: /XiangShan/src/main/scala/xiangshan/backend/fu/PMP.scala (revision b6982e83d6fe4f8c3d111ebc70665f115e470ddf)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.backend.fu
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.internal.naming.chiselName
22import chisel3.util._
23import utils.MaskedRegMap.WritableMask
24import xiangshan._
25import xiangshan.backend.fu.util.HasCSRConst
26import utils._
27import xiangshan.cache.mmu.{TlbCmd, TlbExceptionBundle}
28
/** Constants shared by every PMP bundle/module in this file. */
29trait PMPConst {
30  val PMPOffBits = 2 // minimal 4bytes: pmpaddr registers drop the low 2 byte-address bits
31}
32
/** Base bundle for PMP structures.
  * CoarserGrain is true when the platform grain (2^PlatformGrain bytes) is larger than the
  * minimal 4-byte granule; in that case the NA4 address mode is not selectable (see PMPConfig).
  */
33abstract class PMPBundle(implicit p: Parameters) extends XSBundle with PMPConst {
34  val CoarserGrain: Boolean = PlatformGrain > PMPOffBits
35}
36
/** Base module for PMP logic; mixes in CSR constants (PmpcfgBase, PmpaddrBase, ModeS, ...). */
37abstract class PMPModule(implicit p: Parameters) extends XSModule with PMPConst with HasCSRConst
38
/** One 8-bit pmpNcfg field, laid out msb-to-lsb as l | res(2) | a(2) | x | w | r
  * (field widths below sum to 8 bits).
  */
39@chiselName
40class PMPConfig(implicit p: Parameters) extends PMPBundle {
41  val l = Bool()
42  val res = UInt(2.W)
43  val a = UInt(2.W)
44  val x = Bool()
45  val w = Bool()
46  val r = Bool()
47
  // Address-matching mode decode of `a`: 0=OFF, 1=TOR, 2=NA4, 3=NAPOT.
  // With a coarser-than-4-byte grain NA4 is never selectable, so a(1) alone means NAPOT.
48  def off = a === 0.U
49  def tor = a === 1.U
50  def na4 = { if (CoarserGrain) false.B else a === 2.U }
51  def napot = { if (CoarserGrain) a(1).asBool else a === 3.U }
52  def off_tor = !a(1)
53  def na4_napot = a(1)
54
  // A pmpaddr register is locked either by its own l bit, or when the NEXT entry is a
  // locked TOR entry (its range uses this entry's address as the lower bound).
55  def locked = l
56  def addr_locked: Bool = locked
57  def addr_locked(next: PMPConfig): Bool = locked || (next.locked && next.tor)
58
  /** WARL write transform for one merged pmpcfg CSR (cfgs is XLEN wide, 8 bits per entry).
    * Per entry: W is forced to 0 unless R is also set; with a coarse grain, `a` is collapsed
    * by Cat(a(1), a.orR) so an attempted NA4 (a=2) becomes NAPOT (a=3) while OFF/TOR survive.
    * Relies on Chisel last-connect semantics: the bulk connect is then field-overridden.
    * NOTE(review): the old cfg's lock bit is not checked here, so writes to locked entries
    * appear to take effect — confirm against the caller / privileged spec.
    */
59  def write_cfg_vec(cfgs: UInt): UInt = {
60    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
61    for (i <- cfgVec.indices) {
62      val tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
63      cfgVec(i) := tmp
64      cfgVec(i).w := tmp.w && tmp.r
65      if (CoarserGrain) { cfgVec(i).a := Cat(tmp.a(1), tmp.a.orR) }
66    }
67    cfgVec.asUInt
68  }
69
  /** Same WARL transform as above, but additionally regenerates the stored NAPOT match mask
    * for any entry whose (new) mode is NA4/NAPOT. `index` is the first pmp entry covered by
    * this CSR; `mask`/`addr` are the module-level per-entry register vectors.
    */
70  def write_cfg_vec(mask: Vec[UInt], addr: Vec[UInt], index: Int)(cfgs: UInt): UInt = {
71    val cfgVec = Wire(Vec(cfgs.getWidth/8, new PMPConfig))
72    for (i <- cfgVec.indices) {
73      val tmp = cfgs((i+1)*8-1, i*8).asUInt.asTypeOf(new PMPConfig)
74      cfgVec(i) := tmp
75      cfgVec(i).w := tmp.w && tmp.r
76      if (CoarserGrain) { cfgVec(i).a := Cat(tmp.a(1), tmp.a.orR) }
77      when (cfgVec(i).na4_napot) {
78        mask(index + i) := new PMPEntry().match_mask(cfgVec(i), addr(index + i))
79      }
80    }
81    cfgVec.asUInt
82  }
83
  // Reset state: unlocked, mode OFF (other fields left to register reset/don't-care).
84  def reset() = {
85    l := false.B
86    a := 0.U
87  }
88}
89
90/** PMPBase for CSR unit
91  * with only read and write logic
92  */
93@chiselName
94class PMPBase(implicit p: Parameters) extends PMPBundle {
  // cfg: this entry's pmpNcfg byte; addr: the stored pmpaddr value with the low
  // PMPOffBits (byte offset within the minimal 4-byte granule) already dropped.
95  val cfg = new PMPConfig
96  val addr = UInt((PAddrBits - PMPOffBits).W)
97
98  /** In general, the PMP grain is 2**{G+2} bytes. when G >= 1, na4 is not selectable.
99    * When G >= 2 and cfg.a(1) is set(then the mode is napot), the bits addr(G-2, 0) read as zeros.
100    * When G >= 1 and cfg.a(1) is clear(the mode is off or tor), the addr(G-1, 0) read as zeros.
101    * The low OffBits is dropped
102    */
103  def read_addr(): UInt = {
104    read_addr(cfg)(addr)
105  }
106
  /** WARL read of pmpaddr for grain G = PlatformGrain - PMPOffBits:
    * G == 0 -> raw value; G >= 2 -> NAPOT modes force the low G-1 bits to ones,
    * OFF/TOR force the low G bits to zeros; G == 1 -> only OFF/TOR clear the low bit.
    */
107  def read_addr(cfg: PMPConfig)(addr: UInt): UInt = {
108    val G = PlatformGrain - PMPOffBits
109    require(G >= 0)
110    if (G == 0) {
111      addr
112    } else if (G >= 2) {
113      Mux(cfg.na4_napot, set_low_bits(addr, G-1), clear_low_bits(addr, G))
114    } else { // G is 1
115      Mux(cfg.off_tor, clear_low_bits(addr, G), addr)
116    }
117  }
118  /** addr for inside addr, drop OffBits with.
119    * compare_addr for inside addr for comparing.
120    * paddr for outside addr.
121    */
  // Write is ignored (old value kept) when this entry's address is locked; the two-arg
  // form also considers the NEXT entry (a locked TOR entry locks this address too).
122  def write_addr(next: PMPBase)(paddr: UInt) = {
123    Mux(!cfg.addr_locked(next.cfg), paddr, addr)
124  }
125  def write_addr(paddr: UInt) = {
126    Mux(!cfg.addr_locked, paddr, addr)
127  }
128
  /** Force the low `num` bits of data to ones (used for the NAPOT WARL read). */
129  def set_low_bits(data: UInt, num: Int): UInt = {
130    require(num >= 0)
131    data | ((1 << num)-1).U
132  }
133
134  /** mask the data's low num bits (lsb) */
135  def clear_low_bits(data: UInt, num: Int): UInt = {
136    require(num >= 0)
137    // use Cat instead of & with mask to avoid "Signal Width" problem
138    if (num == 0) { data }
139    else { Cat(data(data.getWidth-1, num), 0.U(num.W)) }
140  }
141
  /** Bulk-connect helper: drive this bundle from an existing cfg/addr pair. */
142  def gen(cfg: PMPConfig, addr: UInt) = {
143    require(addr.getWidth == this.addr.getWidth)
144    this.cfg := cfg
145    this.addr := addr
146  }
147}
148
149/** PMPEntry for outside pmp copies
150  * with one more elements mask to help napot match
151  * TODO: make mask an element, not an method, for timing opt
152  */
153@chiselName
154class PMPEntry(implicit p: Parameters) extends PMPBase {
155  val mask = UInt(PAddrBits.W) // help to match in napot
156
157  /** compare_addr is used to compare with input addr */
  // Stored addr shifted back to a byte address, with sub-grain bits cleared.
158  def compare_addr = ((addr << PMPOffBits) & ~(((1 << PlatformGrain) - 1).U(PAddrBits.W))).asUInt
159
  // Address write that also regenerates the NAPOT match mask; both the mask and the
  // address keep their old value when the entry is locked (see PMPConfig.addr_locked).
160  def write_addr(next: PMPBase, mask: UInt)(paddr: UInt) = {
161    mask := Mux(!cfg.addr_locked(next.cfg), match_mask(paddr), mask)
162    Mux(!cfg.addr_locked(next.cfg), paddr, addr)
163  }
164
165  def write_addr(mask: UInt)(paddr: UInt) = {
166    mask := Mux(!cfg.addr_locked, match_mask(paddr), mask)
167    Mux(!cfg.addr_locked, paddr, addr)
168  }
169  /** size and maxSize are all log2 Size
170    * for dtlb, the maxSize is bXLEN which is 8
171    * for itlb and ptw, the maxSize is log2(512) ?
172    * but we may only need the 64 bytes? how to prevent the bugs?
173    * TODO: handle the special case that itlb & ptw & dcache access wider size than XLEN
174    */
  // Mode dispatch: NAPOT-family -> masked compare, TOR -> range compare against the
  // previous entry's address, OFF -> never matches.
175  def is_match(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry): Bool = {
176    Mux(cfg.na4_napot, napotMatch(paddr, lgSize, lgMaxSize),
177      Mux(cfg.tor, torMatch(paddr, lgSize, lgMaxSize, last_pmp), false.B))
178  }
179
180  /** generate match mask to help match in napot mode */
  // Appending cfg.a(0) (1 for NAPOT) and OR-ing in the grain-ones, then taking the
  // trailing-ones of tmp_addr via tmp_addr & ~(tmp_addr + 1), yields the NAPOT size mask;
  // the low PMPOffBits of the final mask are always ones (minimum 4-byte region).
181  def match_mask(paddr: UInt) = {
182    val tmp_addr = Cat(paddr, cfg.a(0)) | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
183    Cat(tmp_addr & ~(tmp_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
184  }
185
  /** Same computation as above but with an explicit cfg (used while a new cfg is being written). */
186  def match_mask(cfg: PMPConfig, paddr: UInt) = {
187    val tmp_addr = Cat(paddr, cfg.a(0)) | (((1 << PlatformGrain) - 1) >> PMPOffBits).U((paddr.getWidth + 1).W)
188    Cat(tmp_addr & ~(tmp_addr + 1.U), ((1 << PMPOffBits) - 1).U(PMPOffBits.W))
189  }
190
  /** True when the whole access [paddr, paddr + 2^lgSize) lies below compare_addr.
    * When the access can exceed the grain, the compare is split at lgMaxSize: high bits
    * by plain less-than, low bits after OR-ing in the size-1 mask (UIntToOH1) so the
    * access's last byte is what gets compared.
    */
191  def boundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int) = {
192    if (lgMaxSize <= PlatformGrain) {
193      paddr < compare_addr
194    } else {
195      val highLess = (paddr >> lgMaxSize) < (compare_addr >> lgMaxSize)
196      val highEqual = (paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)
197      val lowLess = (paddr(lgMaxSize-1, 0) | OneHot.UIntToOH1(lgSize, lgMaxSize))  < compare_addr(lgMaxSize-1, 0)
198      highLess || (highEqual && lowLess)
199    }
200  }
201
  // TOR lower bound: access is NOT below this entry's address, i.e. paddr >= compare_addr.
202  def lowerBoundMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int) = {
203    !boundMatch(paddr, lgSize, lgMaxSize)
204  }
205
  // TOR upper bound: access start is strictly below this entry's address (size 0 here;
  // partial overlap of larger accesses is handled separately by aligned()).
206  def higherBoundMatch(paddr: UInt, lgMaxSize: Int) = {
207    boundMatch(paddr, 0.U, lgMaxSize)
208  }
209
  // TOR match: last_pmp.compare_addr <= paddr < this.compare_addr.
210  def torMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last_pmp: PMPEntry) = {
211    last_pmp.lowerBoundMatch(paddr, lgSize, lgMaxSize) && higherBoundMatch(paddr, lgMaxSize)
212  }
213
  /** Equality ignoring the bit positions set in m. */
214  def unmaskEqual(a: UInt, b: UInt, m: UInt) = {
215    (a & ~m) === (b & ~m)
216  }
217
  // NAPOT match: paddr equals compare_addr outside the region mask. For accesses that may
  // be wider than the grain, the low compare additionally ignores the access-size bits.
218  def napotMatch(paddr: UInt, lgSize: UInt, lgMaxSize: Int) = {
219    if (lgMaxSize <= PlatformGrain) {
220      unmaskEqual(paddr, compare_addr, mask)
221    } else {
222      val lowMask = mask | OneHot.UIntToOH1(lgSize, lgMaxSize)
223      val highMatch = unmaskEqual(paddr >> lgMaxSize, compare_addr >> lgMaxSize, mask >> lgMaxSize)
224      val lowMatch = unmaskEqual(paddr(lgMaxSize-1, 0), compare_addr(lgMaxSize-1, 0), lowMask(lgMaxSize-1, 0))
225      highMatch && lowMatch
226    }
227  }
228
  /** True when the access does not straddle this entry's region boundary.
    * Trivially true when the access can never exceed the grain; otherwise TOR checks
    * that the access does not partially cross either bound, and NAPOT checks the access
    * footprint is fully covered by the region mask.
    */
229  def aligned(paddr: UInt, lgSize: UInt, lgMaxSize: Int, last: PMPEntry) = {
230    if (lgMaxSize <= PlatformGrain) {
231      true.B
232    } else {
233      val lowBitsMask = OneHot.UIntToOH1(lgSize, lgMaxSize)
234      val lowerBound = ((paddr >> lgMaxSize) === (last.compare_addr >> lgMaxSize)) &&
235        ((~paddr(lgMaxSize-1, 0) & last.compare_addr(lgMaxSize-1, 0)) =/= 0.U)
236      val upperBound = ((paddr >> lgMaxSize) === (compare_addr >> lgMaxSize)) &&
237        ((compare_addr(lgMaxSize-1, 0) & (paddr(lgMaxSize-1, 0) | lowBitsMask)) =/= 0.U)
238      val torAligned = !(lowerBound || upperBound)
239      val napotAligned = (lowBitsMask & ~mask(lgMaxSize-1, 0)) === 0.U
240      Mux(cfg.na4_napot, napotAligned, torAligned)
241    }
242  }
243
  /** Bulk-connect helper including the NAPOT mask. */
244  def gen(cfg: PMPConfig, addr: UInt, mask: UInt) = {
245    require(addr.getWidth == this.addr.getWidth)
246    this.cfg := cfg
247    this.addr := addr
248    this.mask := mask
249  }
250
  // NOTE(review): uses 0.U for the Bool l, where PMPConfig.reset uses false.B —
  // equivalent hardware, but inconsistent style.
251  def reset() = {
252    cfg.l := 0.U
253    cfg.a := 0.U
254  }
255}
256
/** PMP CSR module: holds the pmpcfg/pmpaddr register state, services distributed CSR
  * writes via MaskedRegMap, and exports the decoded entries for the checkers.
  */
257@chiselName
258class PMP(implicit p: Parameters) extends PMPModule {
259  val io = IO(new Bundle {
260    val distribute_csr = Flipped(new DistributedCSRIO())
261    val pmp = Output(Vec(NumPMP, new PMPEntry()))
262  })
263
264  val w = io.distribute_csr.w
265
266  val pmp = Wire(Vec(NumPMP, new PMPEntry()))
267
  // Number of 8-bit cfg fields packed into one XLEN-wide pmpcfg CSR, and the CSR-address
  // stride between pmpcfg registers (RV64 packs 8 cfgs per CSR and uses even CSR numbers).
268  val pmpCfgPerCSR = XLEN / new PMPConfig().getWidth
269  def pmpCfgIndex(i: Int) = (XLEN / 32) * (i / pmpCfgPerCSR)
270
271  /** to fit MaskedRegMap's write, declare cfgs as Merged CSRs and split them into each pmp */
272  val cfgMerged = RegInit(VecInit(Seq.fill(NumPMP / pmpCfgPerCSR)(0.U(XLEN.W))))
273  val cfgs = WireInit(cfgMerged).asTypeOf(Vec(NumPMP, new PMPConfig()))
274  val addr = Reg(Vec(NumPMP, UInt((PAddrBits-PMPOffBits).W)))
275  val mask = Reg(Vec(NumPMP, UInt(PAddrBits.W)))
276
277  for (i <- pmp.indices) {
278    pmp(i).gen(cfgs(i), addr(i), mask(i))
279  }
280
  // pmpcfg write mapping: one merged register per pmpCfgPerCSR entries; the write function
  // applies the WARL transform and regenerates NAPOT masks (PMPConfig.write_cfg_vec).
281  val cfg_mapping = (0 until NumPMP by pmpCfgPerCSR).map(i => {Map(
282    MaskedRegMap(
283      addr = PmpcfgBase + pmpCfgIndex(i),
284      reg = cfgMerged(i/pmpCfgPerCSR),
285      wmask = WritableMask,
286      wfn = new PMPConfig().write_cfg_vec(mask, addr, i)
287    ))
288  }).fold(Map())((a, b) => a ++ b) // ugly code, hit me if u have better codes
289
  // pmpaddr write mapping: writes respect lock bits (including next-entry TOR lock for all
  // but the last entry); reads apply the grain-dependent WARL transform (read_addr).
290  val addr_mapping = (0 until NumPMP).map(i => {Map(
291    MaskedRegMap(
292      addr = PmpaddrBase + i,
293      reg = addr(i),
294      wmask = WritableMask,
295      wfn = { if (i != NumPMP-1) pmp(i).write_addr(pmp(i+1), mask(i)) else pmp(i).write_addr(mask(i)) },
296      rmask = WritableMask,
297      rfn = new PMPBase().read_addr(pmp(i).cfg)
298    ))
299  }).fold(Map())((a, b) => a ++ b) // ugly code, hit me if u have better codes.
300  val pmpMapping =  cfg_mapping ++ addr_mapping
301
  // rdata is not connected to any output here — presumably CSR reads are served elsewhere
  // and this distributed port only handles writes; TODO confirm against DistributedCSRIO users.
302  val rdata = Wire(UInt(XLEN.W))
303  MaskedRegMap.generate(pmpMapping, w.bits.addr, rdata, w.valid, w.bits.data)
304
305  io.pmp := pmp
306}
307
/** Check request: physical address, log2 of the access size (0..lgMaxSize), and TLB command
  * (used to classify the access as load/store/fetch in PMPChecker).
  */
308class PMPReqBundle(lgMaxSize: Int = 3)(implicit p: Parameters) extends PMPBundle {
309  val addr = Output(UInt(PAddrBits.W))
310  val size = Output(UInt(log2Ceil(lgMaxSize+1).W))
311  val cmd = Output(TlbCmd())
312
  // cloneType override required for parameterized Bundles in this chisel3 version.
313  override def cloneType = (new PMPReqBundle(lgMaxSize)).asInstanceOf[this.type]
314}
315
/** Check response: reuses the TLB exception bundle (ld/st/instr access-fault flags). */
316class PMPRespBundle(implicit p: Parameters) extends TlbExceptionBundle
317
/** Combinational PMP permission checker.
  * Walks all NumPMP entries in priority order (entry 0 highest) and reports
  * load/store/fetch access faults; `sameCycle` selects a combinational response,
  * otherwise the result is registered on io.req.valid.
  */
318@chiselName
319class PMPChecker(lgMaxSize: Int = 3, sameCycle: Boolean = false)(implicit p: Parameters) extends PMPModule {
320  val io = IO(new Bundle{
321    val env = Input(new Bundle {
322      val mode = Input(UInt(2.W))
323      val pmp = Input(Vec(NumPMP, new PMPEntry()))
324    })
325    val req = Flipped(Valid(new PMPReqBundle(lgMaxSize))) // usage: assign the valid to fire signal
326    val resp = Output(new PMPRespBundle())
327  })
328
329  val req = io.req.bits
330
  // Default permissions when no entry matches: allow everything in M-mode (mode > ModeS)
  // or when no PMP is implemented; otherwise deny.
331  val passThrough = if (io.env.pmp.isEmpty) true.B else (io.env.mode > ModeS)
  // "pmp minus one": the fall-through entry used as the fold seed and as the notional
  // entry below pmp(0). NOTE(review): name is a typo for "pmpMinusOne" (local only).
332  val pmpMinuxOne = WireInit(0.U.asTypeOf(new PMPEntry()))
333  pmpMinuxOne.cfg.r := passThrough
334  pmpMinuxOne.cfg.w := passThrough
335  pmpMinuxOne.cfg.x := passThrough
336
  // Debug-only wires (dontTouch keeps them in the emitted design for waveform inspection).
337  val match_wave = Wire(Vec(NumPMP, Bool()))
338  val ignore_wave = Wire(Vec(NumPMP, Bool()))
339  val aligned_wave = Wire(Vec(NumPMP, Bool()))
340  val prev_wave = Wire(Vec(NumPMP, new PMPEntry()))
341  val cur_wave = Wire(Vec(NumPMP, new PMPEntry()))
342
343  dontTouch(match_wave)
344  dontTouch(ignore_wave)
345  dontTouch(aligned_wave)
346  dontTouch(prev_wave)
347  dontTouch(cur_wave)
348
  // Priority selection: each entry is zipped with its predecessor (pmpMinuxOne below entry 0,
  // as TOR's lower bound). Folding over the REVERSED list makes lower-indexed entries the
  // outer Mux levels, so the lowest-index matching entry wins, per the priv spec.
349  val res = io.env.pmp.zip(pmpMinuxOne +: io.env.pmp.take(NumPMP-1)).zipWithIndex
350    .reverse.foldLeft(pmpMinuxOne) { case (prev, ((pmp, last_pmp), i)) =>
351    val is_match = pmp.is_match(req.addr, req.size, lgMaxSize, last_pmp)
    // M-mode ignores non-locked entries: their r/w/x are treated as granted.
352    val ignore = passThrough && !pmp.cfg.l
353    val aligned = pmp.aligned(req.addr, req.size, lgMaxSize, last_pmp)
354
    // An access that straddles the region boundary (!aligned) is denied outright.
355    val cur = WireInit(pmp)
356    cur.cfg.r := aligned && (pmp.cfg.r || ignore)
357    cur.cfg.w := aligned && (pmp.cfg.w || ignore)
358    cur.cfg.x := aligned && (pmp.cfg.x || ignore)
359
360    match_wave(i) := is_match
361    ignore_wave(i) := ignore
362    aligned_wave(i) := aligned
363    cur_wave(i) := cur
364    prev_wave(i) := prev
365
    // NOTE(review): debug string has no trailing "\n" — consecutive prints may run together.
366    XSDebug(p"pmp${i.U} cfg:${Hexadecimal(pmp.cfg.asUInt)} addr:${Hexadecimal(pmp.addr)} mask:${Hexadecimal(pmp.mask)} is_match:${is_match} aligned:${aligned}")
367
368    Mux(is_match, cur, prev)
369  }
370
371  // NOTE: if itlb or dtlb may get blocked, this may also need do it
  // Fault classification: atomics count as stores (need W), not plain loads.
372  val ld = TlbCmd.isRead(req.cmd) && !TlbCmd.isAtom(req.cmd) && !res.cfg.r
373  val st = (TlbCmd.isWrite(req.cmd) || TlbCmd.isAtom(req.cmd)) && !res.cfg.w
374  val instr = TlbCmd.isExec(req.cmd) && !res.cfg.x
375  if (sameCycle) {
376    io.resp.ld := ld
377    io.resp.st := st
378    io.resp.instr := instr
379  } else {
380    io.resp.ld := RegEnable(ld, io.req.valid)
381    io.resp.st := RegEnable(st, io.req.valid)
382    io.resp.instr := RegEnable(instr, io.req.valid)
383  }
384}
385