/***************************************************************************************
* Copyright (c) 2024 Beijing Institute of Open Source Chip (BOSC)
* Copyright (c) 2020-2024 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
*
*
* Acknowledgement
*
* This implementation is inspired by several key papers:
* [1] Stephen Somogyi, Thomas F. Wenisch, Anastassia Ailamaki, Babak Falsafi and Andreas Moshovos. "[Spatial memory
* streaming.](https://doi.org/10.1109/ISCA.2006.38)" 33rd International Symposium on Computer Architecture (ISCA).
* 2006.
***************************************************************************************/

package xiangshan.mem.prefetch

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utils._
import utility._
import xiangshan._
import xiangshan.backend.fu.PMPRespBundle
import xiangshan.mem.L1PrefetchReq
import xiangshan.mem.Bundles.LsPrefetchTrainBundle
import xiangshan.mem.trace._
import xiangshan.mem.HasL1PrefetchSourceParameter
import xiangshan.cache.HasDCacheParameters
import xiangshan.cache.mmu._

case class SMSParams
(
  region_size: Int = 1024,
  vaddr_hash_width: Int = 5,
  block_addr_raw_width: Int = 10,
  stride_pc_bits: Int = 10,
  max_stride: Int = 1024,
  stride_entries: Int = 16,
  active_gen_table_size: Int = 16,
  pht_size: Int = 64,
  pht_ways: Int = 2,
  pht_hist_bits: Int = 2,
  pht_tag_bits: Int = 13,
  pht_lookup_queue_size: Int = 4,
  pf_filter_size: Int = 16,
  train_filter_size: Int = 8
) extends PrefetcherParams

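// Derived geometry, assuming the typical XiangShan configuration of 64B dcache
// blocks and 4KiB pages (an assumption of this note, not enforced here): with
// region_size = 1024, a region covers REGION_BLKS = 1024 / 64 = 16 cache
// blocks, REGION_OFFSET = 4 bits select a block within a region, and
// REGION_ADDR_PAGE_BIT = log2(4096 / 1024) = 2 is the lowest page-number bit
// of a region address.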
trait HasSMSModuleHelper extends HasCircularQueuePtrHelper with HasDCacheParameters
{ this: HasXSParameter =>
  val smsParams = coreParams.prefetcher.get.asInstanceOf[SMSParams]
  val BLK_ADDR_WIDTH = VAddrBits - log2Up(dcacheParameters.blockBytes)
  val REGION_SIZE = smsParams.region_size
  val REGION_BLKS = smsParams.region_size / dcacheParameters.blockBytes
  val REGION_ADDR_BITS = VAddrBits - log2Up(REGION_SIZE)
  val REGION_OFFSET = log2Up(REGION_BLKS)
  val VADDR_HASH_WIDTH = smsParams.vaddr_hash_width
  val BLK_ADDR_RAW_WIDTH = smsParams.block_addr_raw_width
  val REGION_ADDR_RAW_WIDTH = BLK_ADDR_RAW_WIDTH - REGION_OFFSET
  val BLK_TAG_WIDTH = BLK_ADDR_RAW_WIDTH + VADDR_HASH_WIDTH
  val REGION_TAG_WIDTH = REGION_ADDR_RAW_WIDTH + VADDR_HASH_WIDTH
  val PHT_INDEX_BITS = log2Up(smsParams.pht_size / smsParams.pht_ways)
  val PHT_TAG_BITS = smsParams.pht_tag_bits
  val PHT_HIST_BITS = smsParams.pht_hist_bits
  // page bit index in block addr
  val BLOCK_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / dcacheParameters.blockBytes)
  val REGION_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / smsParams.region_size)
  val STRIDE_PC_BITS = smsParams.stride_pc_bits
  val STRIDE_BLK_ADDR_BITS = log2Up(smsParams.max_stride)

  def block_addr(x: UInt): UInt = {
    val offset = log2Up(dcacheParameters.blockBytes)
    x(x.getWidth - 1, offset)
  }

  def region_addr(x: UInt): UInt = {
    val offset = log2Up(REGION_SIZE)
    x(x.getWidth - 1, offset)
  }

  def region_offset_to_bits(off: UInt): UInt = {
    (1.U << off).asUInt
  }

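  // Hashed tags: region_hash_tag/block_hash_tag keep the low
  // REGION_ADDR_RAW_WIDTH/BLK_ADDR_RAW_WIDTH address bits verbatim and
  // compress the next 3 * VADDR_HASH_WIDTH bits into VADDR_HASH_WIDTH bits via
  // the XOR fold in vaddr_hash below, trading a small aliasing probability for
  // narrower tag storage.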
  def region_hash_tag(rg_addr: UInt): UInt = {
    val low = rg_addr(REGION_ADDR_RAW_WIDTH - 1, 0)
    val high = rg_addr(REGION_ADDR_RAW_WIDTH + 3 * VADDR_HASH_WIDTH - 1, REGION_ADDR_RAW_WIDTH)
    val high_hash = vaddr_hash(high)
    Cat(high_hash, low)
  }

  def page_bit(region_addr: UInt): UInt = {
    region_addr(log2Up(dcacheParameters.pageSize/REGION_SIZE))
  }

  def block_hash_tag(x: UInt): UInt = {
    val blk_addr = block_addr(x)
    val low = blk_addr(BLK_ADDR_RAW_WIDTH - 1, 0)
    val high = blk_addr(BLK_ADDR_RAW_WIDTH - 1 + 3 * VADDR_HASH_WIDTH, BLK_ADDR_RAW_WIDTH)
    val high_hash = vaddr_hash(high)
    Cat(high_hash, low)
  }

  def vaddr_hash(x: UInt): UInt = {
    val width = VADDR_HASH_WIDTH
    val low = x(width - 1, 0)
    val mid = x(2 * width - 1, width)
    val high = x(3 * width - 1, 2 * width)
    low ^ mid ^ high
  }

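  // The PHT is indexed and tagged by load PC: the index takes the pc bits just
  // above the 2B-alignment bit, folding pc(1) into the top index bit via XOR
  // with pc(PHT_INDEX_BITS + 1); the tag takes the PHT_TAG_BITS directly above
  // the index field.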
  def pht_index(pc: UInt): UInt = {
    val low_bits = pc(PHT_INDEX_BITS, 2)
    val hi_bit = pc(1) ^ pc(PHT_INDEX_BITS+1)
    Cat(hi_bit, low_bits)
  }

  def pht_tag(pc: UInt): UInt = {
    pc(PHT_INDEX_BITS + 2 + PHT_TAG_BITS - 1, PHT_INDEX_BITS + 2)
  }

  def get_alias_bits(region_vaddr: UInt): UInt = {
    val offset = log2Up(REGION_SIZE)
    get_alias(Cat(region_vaddr, 0.U(offset.W)))
  }
}

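// StridePF: a small PC-indexed stride predictor. Each entry tracks the last
// block address and block stride seen by one load PC together with a 2-bit
// confidence counter; once a stride is confirmed, it prefetches
// last_addr + stride and hands the result to the SMS pipeline as a
// region-granular request.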
class StridePF()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val stride_en = Input(Bool())
    val s0_lookup = Flipped(new ValidIO(new Bundle() {
      val pc = UInt(STRIDE_PC_BITS.W)
      val vaddr = UInt(VAddrBits.W)
      val paddr = UInt(PAddrBits.W)
    }))
    val s1_valid = Input(Bool())
    val s2_gen_req = ValidIO(new PfGenReq())
  })

  val prev_valid = GatedValidRegNext(io.s0_lookup.valid, false.B)
  val prev_pc = RegEnable(io.s0_lookup.bits.pc, io.s0_lookup.valid)

  val s0_valid = io.s0_lookup.valid && !(prev_valid && prev_pc === io.s0_lookup.bits.pc)

  def entry_map[T](fn: Int => T) = (0 until smsParams.stride_entries).map(fn)

  val replacement = ReplacementPolicy.fromString("plru", smsParams.stride_entries)
  val valids = entry_map(_ => RegInit(false.B))
  val entries_pc = entry_map(_ => Reg(UInt(STRIDE_PC_BITS.W)))
  val entries_conf = entry_map(_ => RegInit(1.U(2.W)))
  val entries_last_addr = entry_map(_ => Reg(UInt(STRIDE_BLK_ADDR_BITS.W)))
  val entries_stride = entry_map(_ => Reg(SInt((STRIDE_BLK_ADDR_BITS+1).W)))

  val s0_match_vec = valids.zip(entries_pc).map({
    case (v, pc) => v && pc === io.s0_lookup.bits.pc
  })

  val s0_hit = s0_valid && Cat(s0_match_vec).orR
  val s0_miss = s0_valid && !s0_hit
  val s0_matched_conf = Mux1H(s0_match_vec, entries_conf)
  val s0_matched_last_addr = Mux1H(s0_match_vec, entries_last_addr)
  val s0_matched_last_stride = Mux1H(s0_match_vec, entries_stride)

  val s1_hit = GatedValidRegNext(s0_hit) && io.s1_valid
  val s1_alloc = GatedValidRegNext(s0_miss) && io.s1_valid
  val s1_vaddr = RegEnable(io.s0_lookup.bits.vaddr, s0_valid)
  val s1_paddr = RegEnable(io.s0_lookup.bits.paddr, s0_valid)
  val s1_conf = RegEnable(s0_matched_conf, s0_valid)
  val s1_last_addr = RegEnable(s0_matched_last_addr, s0_valid)
  val s1_last_stride = RegEnable(s0_matched_last_stride, s0_valid)
  val s1_match_vec = RegEnable(VecInit(s0_match_vec), s0_valid)

  val BLOCK_OFFSET = log2Up(dcacheParameters.blockBytes)
  val s1_new_stride_vaddr = s1_vaddr(BLOCK_OFFSET + STRIDE_BLK_ADDR_BITS - 1, BLOCK_OFFSET)
  val s1_new_stride = (0.U(1.W) ## s1_new_stride_vaddr).asSInt - (0.U(1.W) ## s1_last_addr).asSInt
  val s1_stride_non_zero = s1_last_stride =/= 0.S
  val s1_stride_match = s1_new_stride === s1_last_stride && s1_stride_non_zero
  val s1_replace_idx = replacement.way

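  // Entry update: the 2-bit confidence saturates up on a stride match and down
  // on a mismatch; while confidence is still low (s1_conf < 2) the recorded
  // stride is re-learned from the latest delta.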
  for(i <- 0 until smsParams.stride_entries){
    val alloc = s1_alloc && i.U === s1_replace_idx
    val update = s1_hit && s1_match_vec(i)
    when(update){
      assert(valids(i))
      entries_conf(i) := Mux(s1_stride_match,
        Mux(s1_conf === 3.U, 3.U, s1_conf + 1.U),
        Mux(s1_conf === 0.U, 0.U, s1_conf - 1.U)
      )
      entries_last_addr(i) := s1_new_stride_vaddr
      when(!s1_conf(1)){
        entries_stride(i) := s1_new_stride
      }
    }
    when(alloc){
      valids(i) := true.B
      entries_pc(i) := prev_pc
      entries_conf(i) := 0.U
      entries_last_addr(i) := s1_new_stride_vaddr
      entries_stride(i) := 0.S
    }
    assert(!(update && alloc))
  }
  when(s1_hit){
    replacement.access(OHToUInt(s1_match_vec.asUInt))
  }.elsewhen(s1_alloc){
    replacement.access(s1_replace_idx)
  }

  val s1_block_vaddr = block_addr(s1_vaddr)
  val s1_pf_block_vaddr = (s1_block_vaddr.asSInt + s1_last_stride).asUInt
  val s1_pf_cross_page = s1_pf_block_vaddr(BLOCK_ADDR_PAGE_BIT) =/= s1_block_vaddr(BLOCK_ADDR_PAGE_BIT)

  val s2_pf_gen_valid = GatedValidRegNext(s1_hit && s1_stride_match, false.B)
  val s2_pf_gen_paddr_valid = RegEnable(!s1_pf_cross_page, s1_hit && s1_stride_match)
  val s2_pf_block_vaddr = RegEnable(s1_pf_block_vaddr, s1_hit && s1_stride_match)
  val s2_block_paddr = RegEnable(block_addr(s1_paddr), s1_hit && s1_stride_match)

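  // If the prefetch target stays in the same page, splice the translated page
  // frame from the trained paddr onto the target's in-page offset; otherwise
  // fall back to the virtual block address and let the prefetch filter
  // translate it later.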
  val s2_pf_block_addr = Mux(s2_pf_gen_paddr_valid,
    Cat(
      s2_block_paddr(PAddrBits - BLOCK_OFFSET - 1, BLOCK_ADDR_PAGE_BIT),
      s2_pf_block_vaddr(BLOCK_ADDR_PAGE_BIT - 1, 0)
    ),
    s2_pf_block_vaddr
  )
  val s2_pf_full_addr = Wire(UInt(VAddrBits.W))
  s2_pf_full_addr := s2_pf_block_addr ## 0.U(BLOCK_OFFSET.W)

  val s2_pf_region_addr = region_addr(s2_pf_full_addr)
  val s2_pf_region_offset = s2_pf_block_addr(REGION_OFFSET - 1, 0)

  val s2_full_vaddr = Wire(UInt(VAddrBits.W))
  s2_full_vaddr := s2_pf_block_vaddr ## 0.U(BLOCK_OFFSET.W)

  val s2_region_tag = region_hash_tag(region_addr(s2_full_vaddr))

  io.s2_gen_req.valid := s2_pf_gen_valid && io.stride_en
  io.s2_gen_req.bits.region_tag := s2_region_tag
  io.s2_gen_req.bits.region_addr := s2_pf_region_addr
  io.s2_gen_req.bits.alias_bits := get_alias_bits(region_addr(s2_full_vaddr))
  io.s2_gen_req.bits.region_bits := region_offset_to_bits(s2_pf_region_offset)
  io.s2_gen_req.bits.paddr_valid := s2_pf_gen_paddr_valid
  io.s2_gen_req.bits.decr_mode := false.B
  io.s2_gen_req.bits.debug_source_type := HW_PREFETCH_STRIDE.U

}

class AGTEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)
  val pht_tag = UInt(PHT_TAG_BITS.W)
  val region_bits = UInt(REGION_BLKS.W)
  val region_bit_single = UInt(REGION_BLKS.W)
  val region_tag = UInt(REGION_TAG_WIDTH.W)
  val region_offset = UInt(REGION_OFFSET.W)
  val access_cnt = UInt((REGION_BLKS-1).U.getWidth.W)
  val decr_mode = Bool()
  val single_update = Bool() // this entry carries a single-update request ("signal" in related names is a misspelling of "single")
  val has_been_signal_updated = Bool()
}

class PfGenReq()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val region_tag = UInt(REGION_TAG_WIDTH.W)
  val region_addr = UInt(REGION_ADDR_BITS.W)
  val region_bits = UInt(REGION_BLKS.W)
  val paddr_valid = Bool()
  val decr_mode = Bool()
  val alias_bits = UInt(2.W)
  val debug_source_type = UInt(log2Up(nSourceType).W)
}

class AGTEvictReq()(implicit p: Parameters) extends XSBundle {
  val vaddr = UInt(VAddrBits.W)
}

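// ActiveGenerationTable: tracks regions with an in-flight spatial generation.
// A training access either updates the matching entry's access bitmap or
// allocates a new entry; entries evicted (by replacement or by a dcache
// eviction) carry their observed pattern to the PatternHistoryTable.
// Near-neighbor region hits also generate prefetches directly.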
class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val agt_en = Input(Bool())
    val s0_lookup = Flipped(ValidIO(new Bundle() {
      val region_tag = UInt(REGION_TAG_WIDTH.W)
      val region_p1_tag = UInt(REGION_TAG_WIDTH.W)
      val region_m1_tag = UInt(REGION_TAG_WIDTH.W)
      val region_offset = UInt(REGION_OFFSET.W)
      val pht_index = UInt(PHT_INDEX_BITS.W)
      val pht_tag = UInt(PHT_TAG_BITS.W)
      val allow_cross_region_p1 = Bool()
      val allow_cross_region_m1 = Bool()
      val region_p1_cross_page = Bool()
      val region_m1_cross_page = Bool()
      val region_paddr = UInt(REGION_ADDR_BITS.W)
      val region_vaddr = UInt(REGION_ADDR_BITS.W)
    }))
    // dcache has released a block, evict it from agt
    val s0_dcache_evict = Flipped(DecoupledIO(new AGTEvictReq))
    val s1_sel_stride = Output(Bool())
    val s2_stride_hit = Input(Bool())
    // if agt/stride missed, try lookup pht
    val s2_pht_lookup = ValidIO(new PhtLookup())
    // evict entry to pht
    val s2_evict = ValidIO(new AGTEntry())
    val s2_pf_gen_req = ValidIO(new PfGenReq())
    val act_threshold = Input(UInt(REGION_OFFSET.W))
    val act_stride = Input(UInt(6.W))
  })

  val entries = Seq.fill(smsParams.active_gen_table_size){ Reg(new AGTEntry()) }
  val valids = Seq.fill(smsParams.active_gen_table_size){ RegInit(false.B) }
  val replacement = ReplacementPolicy.fromString("plru", smsParams.active_gen_table_size)

  val s1_replace_mask_w = Wire(UInt(smsParams.active_gen_table_size.W))

  val s0_lookup = io.s0_lookup.bits
  val s0_lookup_valid = io.s0_lookup.valid

  val s0_dcache_evict = io.s0_dcache_evict.bits
  val s0_dcache_evict_valid = io.s0_dcache_evict.valid
  val s0_dcache_evict_tag = block_hash_tag(s0_dcache_evict.vaddr).head(REGION_TAG_WIDTH)

  val prev_lookup = RegEnable(s0_lookup, s0_lookup_valid)
  val prev_lookup_valid = GatedValidRegNext(s0_lookup_valid, false.B)

  val s0_match_prev = prev_lookup_valid && s0_lookup.region_tag === prev_lookup.region_tag

  def gen_match_vec(region_tag: UInt): Seq[Bool] = {
    entries.zip(valids).map({
      case (ent, v) => v && ent.region_tag === region_tag
    })
  }

  val region_match_vec_s0 = gen_match_vec(s0_lookup.region_tag)
  val region_p1_match_vec_s0 = gen_match_vec(s0_lookup.region_p1_tag)
  val region_m1_match_vec_s0 = gen_match_vec(s0_lookup.region_m1_tag)

  val any_region_match = Cat(region_match_vec_s0).orR
  val any_region_p1_match = Cat(region_p1_match_vec_s0).orR && s0_lookup.allow_cross_region_p1
  val any_region_m1_match = Cat(region_m1_match_vec_s0).orR && s0_lookup.allow_cross_region_m1

  val region_match_vec_dcache_evict_s0 = gen_match_vec(s0_dcache_evict_tag)
  val any_region_dcache_evict_match = Cat(region_match_vec_dcache_evict_s0).orR
  // s0: the dcache evicts an entry that may be replaced in s1
  val s0_dcache_evict_conflict = Cat(VecInit(region_match_vec_dcache_evict_s0).asUInt & s1_replace_mask_w).orR
  val s0_do_dcache_evict = io.s0_dcache_evict.fire && any_region_dcache_evict_match

  io.s0_dcache_evict.ready := !s0_lookup_valid && !s0_dcache_evict_conflict

  val s0_region_hit = any_region_match
  val s0_cross_region_hit = any_region_m1_match || any_region_p1_match
  val s0_alloc = s0_lookup_valid && !s0_region_hit && !s0_match_prev
  val s0_pf_gen_match_vec = valids.indices.map(i => {
    Mux(any_region_match,
      region_match_vec_s0(i),
      Mux(any_region_m1_match,
        region_m1_match_vec_s0(i), region_p1_match_vec_s0(i)
      )
    )
  })
  val s0_agt_entry = Wire(new AGTEntry())

  s0_agt_entry.pht_index := s0_lookup.pht_index
  s0_agt_entry.pht_tag := s0_lookup.pht_tag
  s0_agt_entry.region_bits := region_offset_to_bits(s0_lookup.region_offset)
  // bits updated by this access only
  s0_agt_entry.region_bit_single := region_offset_to_bits(s0_lookup.region_offset)
  s0_agt_entry.region_tag := s0_lookup.region_tag
  s0_agt_entry.region_offset := s0_lookup.region_offset
  s0_agt_entry.access_cnt := 1.U

  s0_agt_entry.has_been_signal_updated := false.B
  // lookup_region + 1 == entry_region
  // lookup_region = entry_region - 1 => decr mode
  s0_agt_entry.decr_mode := !s0_region_hit && !any_region_m1_match && any_region_p1_match
  val s0_replace_way = replacement.way
  val s0_replace_mask = UIntToOH(s0_replace_way)
  // s0 hits an entry that may be replaced in s1
  val s0_update_conflict = Cat(VecInit(region_match_vec_s0).asUInt & s1_replace_mask_w).orR
  val s0_update = s0_lookup_valid && s0_region_hit && !s0_update_conflict
  s0_agt_entry.single_update := s0_update

  val s0_access_way = Mux1H(
    Seq(s0_update, s0_alloc),
    Seq(OHToUInt(region_match_vec_s0), s0_replace_way)
  )
  when(s0_update || s0_alloc) {
    replacement.access(s0_access_way)
  }

  // stage1: update/alloc
  // region hit, update entry
  val s1_update = GatedValidRegNext(s0_update, false.B)
  val s1_update_mask = RegEnable(VecInit(region_match_vec_s0), s0_lookup_valid)
  val s1_agt_entry = RegEnable(s0_agt_entry, s0_lookup_valid)
  val s1_cross_region_match = RegEnable(s0_cross_region_hit, s0_lookup_valid)
  val s1_alloc = GatedValidRegNext(s0_alloc, false.B)
  val s1_alloc_entry = s1_agt_entry
  val s1_do_dcache_evict = GatedValidRegNext(s0_do_dcache_evict, false.B)
  val s1_replace_mask = Mux(
    s1_do_dcache_evict,
    RegEnable(VecInit(region_match_vec_dcache_evict_s0).asUInt, s0_do_dcache_evict),
    RegEnable(s0_replace_mask, s0_lookup_valid)
  )
  s1_replace_mask_w := s1_replace_mask & Fill(smsParams.active_gen_table_size, s1_alloc || s1_do_dcache_evict)
  val s1_evict_entry = Mux1H(s1_replace_mask, entries)
  val s1_evict_valid = Mux1H(s1_replace_mask, valids)
  // pf gen
  val s1_pf_gen_match_vec = RegEnable(VecInit(s0_pf_gen_match_vec), s0_lookup_valid)
  val s1_region_paddr = RegEnable(s0_lookup.region_paddr, s0_lookup_valid)
  val s1_region_vaddr = RegEnable(s0_lookup.region_vaddr, s0_lookup_valid)
  val s1_region_offset = RegEnable(s0_lookup.region_offset, s0_lookup_valid)
  val s1_bit_region_signal = RegEnable(region_offset_to_bits(s0_lookup.region_offset), s0_lookup_valid)

  for(i <- entries.indices){
    val alloc = s1_replace_mask(i) && s1_alloc
    val update = s1_update_mask(i) && s1_update
    val update_entry = WireInit(entries(i))
    update_entry.region_bits := entries(i).region_bits | s1_agt_entry.region_bits
    update_entry.access_cnt := Mux(entries(i).access_cnt === (REGION_BLKS - 1).U,
      entries(i).access_cnt,
      entries(i).access_cnt + (s1_agt_entry.region_bits & (~entries(i).region_bits).asUInt).orR
    )
    update_entry.region_bit_single := s1_agt_entry.region_bit_single
    update_entry.has_been_signal_updated := entries(i).has_been_signal_updated ||
      (!((s1_alloc || s1_do_dcache_evict) && s1_evict_valid) && s1_update)
    valids(i) := valids(i) || alloc
    entries(i) := Mux(alloc, s1_alloc_entry, Mux(update, update_entry, entries(i)))
  }

  val s1_update_entry = Mux1H(s1_update_mask, entries)
  val s1_update_valid = Mux1H(s1_update_mask, valids)

  when(s1_update){
    assert(PopCount(s1_update_mask) === 1.U, "multi-agt-update")
  }
  when(s1_alloc){
    assert(PopCount(s1_replace_mask) === 1.U, "multi-agt-alloc")
  }

  // pf_addr
  // 1. hit => pf_addr = lookup_addr + (decr ? -1 : 1)
  // 2. lookup region - 1 hit => lookup_addr + 1 (incr mode)
  // 3. lookup region + 1 hit => lookup_addr - 1 (decr mode)
  val s1_hited_entry_decr = Mux1H(s1_update_mask, entries.map(_.decr_mode))
  val s1_pf_gen_decr_mode = Mux(s1_update,
    s1_hited_entry_decr,
    s1_agt_entry.decr_mode
  )

  val s1_pf_gen_vaddr_inc = Cat(0.U, s1_region_vaddr(REGION_TAG_WIDTH - 1, 0), s1_region_offset) + io.act_stride
  val s1_pf_gen_vaddr_dec = Cat(0.U, s1_region_vaddr(REGION_TAG_WIDTH - 1, 0), s1_region_offset) - io.act_stride
  val s1_vaddr_inc_cross_page = s1_pf_gen_vaddr_inc(BLOCK_ADDR_PAGE_BIT) =/= s1_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s1_vaddr_dec_cross_page = s1_pf_gen_vaddr_dec(BLOCK_ADDR_PAGE_BIT) =/= s1_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s1_vaddr_inc_cross_max_lim = s1_pf_gen_vaddr_inc.head(1).asBool
  val s1_vaddr_dec_cross_max_lim = s1_pf_gen_vaddr_dec.head(1).asBool

  //val s1_pf_gen_vaddr_p1 = s1_region_vaddr(REGION_TAG_WIDTH - 1, 0) + 1.U
  //val s1_pf_gen_vaddr_m1 = s1_region_vaddr(REGION_TAG_WIDTH - 1, 0) - 1.U
  val s1_pf_gen_vaddr = Cat(
    s1_region_vaddr(REGION_ADDR_BITS - 1, REGION_TAG_WIDTH),
    Mux(s1_pf_gen_decr_mode,
      s1_pf_gen_vaddr_dec.tail(1).head(REGION_TAG_WIDTH),
      s1_pf_gen_vaddr_inc.tail(1).head(REGION_TAG_WIDTH)
    )
  )
  val s1_pf_gen_offset = Mux(s1_pf_gen_decr_mode,
    s1_pf_gen_vaddr_dec(REGION_OFFSET - 1, 0),
    s1_pf_gen_vaddr_inc(REGION_OFFSET - 1, 0)
  )
  val s1_pf_gen_offset_mask = UIntToOH(s1_pf_gen_offset)
  val s1_pf_gen_access_cnt = Mux1H(s1_pf_gen_match_vec, entries.map(_.access_cnt))
  val s1_in_active_page = s1_pf_gen_access_cnt > io.act_threshold
  val s1_pf_gen_valid = prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && Mux(s1_pf_gen_decr_mode,
    !s1_vaddr_dec_cross_max_lim,
    !s1_vaddr_inc_cross_max_lim
  ) && s1_in_active_page && io.agt_en
  val s1_pf_gen_paddr_valid = Mux(s1_pf_gen_decr_mode, !s1_vaddr_dec_cross_page, !s1_vaddr_inc_cross_page)
  val s1_pf_gen_region_addr = Mux(s1_pf_gen_paddr_valid,
    Cat(s1_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT), s1_pf_gen_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)),
    s1_pf_gen_vaddr
  )
  val s1_pf_gen_region_tag = region_hash_tag(s1_pf_gen_vaddr)
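  // Region bits to prefetch within the target region: an increasing stream
  // enters a region at block 0 and climbs, so incr mode selects blocks 0..o
  // (o = generated offset); a decreasing stream mirrors this, so decr mode
  // selects blocks o..REGION_BLKS-1.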
  val s1_pf_gen_incr_region_bits = VecInit((0 until REGION_BLKS).map(i => {
    if(i == 0) true.B else !s1_pf_gen_offset_mask(i - 1, 0).orR
  })).asUInt
  val s1_pf_gen_decr_region_bits = VecInit((0 until REGION_BLKS).map(i => {
    if(i == REGION_BLKS - 1) true.B
    else !s1_pf_gen_offset_mask(REGION_BLKS - 1, i + 1).orR
  })).asUInt
  val s1_pf_gen_region_bits = Mux(s1_pf_gen_decr_mode,
    s1_pf_gen_decr_region_bits,
    s1_pf_gen_incr_region_bits
  )
  val s1_pht_lookup_valid = Wire(Bool())
  val s1_pht_lookup = Wire(new PhtLookup())

  s1_pht_lookup_valid := !s1_pf_gen_valid && prev_lookup_valid
  s1_pht_lookup.pht_index := s1_agt_entry.pht_index
  s1_pht_lookup.pht_tag := s1_agt_entry.pht_tag
  s1_pht_lookup.region_vaddr := s1_region_vaddr
  s1_pht_lookup.region_paddr := s1_region_paddr
  s1_pht_lookup.region_offset := s1_region_offset
  s1_pht_lookup.region_bit_single := s1_bit_region_signal

  io.s1_sel_stride := prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && !s1_in_active_page

  // stage2: gen pf req / evict entry to pht
  // if nothing was evicted, forward this access's region bits to the pht as a single update
  val s2_do_dcache_evict = GatedValidRegNext(s1_do_dcache_evict, false.B)
  val s1_send_update_entry = Mux((s1_alloc || s1_do_dcache_evict) && s1_evict_valid, s1_evict_entry, s1_update_entry)
  val s2_evict_entry = RegEnable(s1_send_update_entry, s1_alloc || s1_do_dcache_evict || s1_update)
  val s2_evict_valid = GatedValidRegNext(((s1_alloc || s1_do_dcache_evict) && s1_evict_valid) || s1_update, false.B)
  val s2_update = RegNext(s1_update, false.B)
  val s2_real_update = RegNext((s1_alloc || s1_do_dcache_evict) && s1_evict_valid, false.B)
  val s2_paddr_valid = RegEnable(s1_pf_gen_paddr_valid, s1_pf_gen_valid)
  val s2_pf_gen_region_tag = RegEnable(s1_pf_gen_region_tag, s1_pf_gen_valid)
  val s2_pf_gen_decr_mode = RegEnable(s1_pf_gen_decr_mode, s1_pf_gen_valid)
  val s2_pf_gen_region_paddr = RegEnable(s1_pf_gen_region_addr, s1_pf_gen_valid)
  val s2_pf_gen_alias_bits = RegEnable(get_alias_bits(s1_pf_gen_vaddr), s1_pf_gen_valid)
  val s2_pf_gen_region_bits = RegEnable(s1_pf_gen_region_bits, s1_pf_gen_valid)
  val s2_pf_gen_valid = GatedValidRegNext(s1_pf_gen_valid, false.B)
  val s2_pht_lookup_valid = GatedValidRegNext(s1_pht_lookup_valid, false.B) && !io.s2_stride_hit
  val s2_pht_lookup = RegEnable(s1_pht_lookup, s1_pht_lookup_valid)

  io.s2_evict.valid := Mux(s2_real_update, s2_evict_valid && (s2_evict_entry.access_cnt > 1.U), s2_evict_valid)
  io.s2_evict.bits := s2_evict_entry
  io.s2_evict.bits.single_update := s2_update && !s2_real_update

  io.s2_pf_gen_req.bits.region_tag := s2_pf_gen_region_tag
  io.s2_pf_gen_req.bits.region_addr := s2_pf_gen_region_paddr
  io.s2_pf_gen_req.bits.alias_bits := s2_pf_gen_alias_bits
  io.s2_pf_gen_req.bits.region_bits := s2_pf_gen_region_bits
  io.s2_pf_gen_req.bits.paddr_valid := s2_paddr_valid
  io.s2_pf_gen_req.bits.decr_mode := s2_pf_gen_decr_mode
  io.s2_pf_gen_req.valid := false.B
  io.s2_pf_gen_req.bits.debug_source_type := HW_PREFETCH_AGT.U

  io.s2_pht_lookup.valid := s2_pht_lookup_valid
  io.s2_pht_lookup.bits := s2_pht_lookup

  XSPerfAccumulate("sms_agt_in", io.s0_lookup.valid)
  XSPerfAccumulate("sms_agt_alloc", s1_alloc) // cross region match or filter evict
  XSPerfAccumulate("sms_agt_update", s1_update) // entry hit
  XSPerfAccumulate("sms_agt_pf_gen", io.s2_pf_gen_req.valid)
  XSPerfAccumulate("sms_agt_pf_gen_paddr_valid",
    io.s2_pf_gen_req.valid && io.s2_pf_gen_req.bits.paddr_valid
  )
  XSPerfAccumulate("sms_agt_pf_gen_decr_mode",
    io.s2_pf_gen_req.valid && io.s2_pf_gen_req.bits.decr_mode
  )
  for(i <- 0 until smsParams.active_gen_table_size){
    XSPerfAccumulate(s"sms_agt_access_entry_$i",
      s1_alloc && s1_replace_mask(i) || s1_update && s1_update_mask(i)
    )
  }
  XSPerfAccumulate("sms_agt_evict", s2_evict_valid)
  XSPerfAccumulate("sms_agt_evict_by_plru", s2_evict_valid && !s2_do_dcache_evict)
  XSPerfAccumulate("sms_agt_evict_by_dcache", s2_evict_valid && s2_do_dcache_evict)
  XSPerfAccumulate("sms_agt_evict_one_hot_pattern", s2_evict_valid && (s2_evict_entry.access_cnt === 1.U))
}

class PhtLookup()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)
  val pht_tag = UInt(PHT_TAG_BITS.W)
  val region_paddr = UInt(REGION_ADDR_BITS.W)
  val region_vaddr = UInt(REGION_ADDR_BITS.W)
  val region_offset = UInt(REGION_OFFSET.W)
  val region_bit_single = UInt(REGION_BLKS.W)
}

class PhtEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val hist = Vec(2 * (REGION_BLKS - 1), UInt(PHT_HIST_BITS.W))
  val tag = UInt(PHT_TAG_BITS.W)
  val decr_mode = Bool()
}

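// PatternHistoryTable: a set-associative, PC-indexed table of spatial access
// patterns backed by a single-ported SRAM. Evicted AGT entries train the
// per-block history counters; AGT misses look the table up and, on a hit,
// regenerate prefetches for the current, next and previous regions.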
class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    // receive agt evicted entry
    val agt_update = Flipped(ValidIO(new AGTEntry()))
    // at stage2, if we know agt missed, lookup pht
    val s2_agt_lookup = Flipped(ValidIO(new PhtLookup()))
    // pht-generated prefetch req
    val pf_gen_req = ValidIO(new PfGenReq())
  })

  val pht_ram = Module(new SRAMTemplate[PhtEntry](new PhtEntry,
    set = smsParams.pht_size / smsParams.pht_ways,
    way = smsParams.pht_ways,
    singlePort = true,
    withClockGate = true,
    hasMbist = hasMbist,
    hasSramCtl = hasSramCtl
  ))
  def PHT_SETS = smsParams.pht_size / smsParams.pht_ways
  // clock-gated on pht_valids
  val pht_valids_reg = RegInit(VecInit(Seq.fill(smsParams.pht_ways){
    VecInit(Seq.fill(PHT_SETS){false.B})
  }))
  val pht_valids_enable = WireInit(VecInit(Seq.fill(PHT_SETS){false.B}))
  val pht_valids_next = WireInit(pht_valids_reg)
  for(j <- 0 until PHT_SETS){
    when(pht_valids_enable(j)){
      (0 until smsParams.pht_ways).foreach(i => pht_valids_reg(i)(j) := pht_valids_next(i)(j))
    }
  }

  val replacement = Seq.fill(PHT_SETS) { ReplacementPolicy.fromString("plru", smsParams.pht_ways) }

  val lookup_queue = Module(new OverrideableQueue(new PhtLookup, smsParams.pht_lookup_queue_size))
  lookup_queue.io.in := io.s2_agt_lookup
  val lookup = lookup_queue.io.out

  val evict_queue = Module(new OverrideableQueue(new AGTEntry, smsParams.pht_lookup_queue_size))
  evict_queue.io.in := io.agt_update
  val evict = evict_queue.io.out

  XSPerfAccumulate("sms_pht_lookup_in", lookup_queue.io.in.fire)
  XSPerfAccumulate("sms_pht_lookup_out", lookup_queue.io.out.fire)
  XSPerfAccumulate("sms_pht_evict_in", evict_queue.io.in.fire)
  XSPerfAccumulate("sms_pht_evict_out", evict_queue.io.out.fire)

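  // The PHT body is a 4-stage pipeline around the single-ported SRAM:
  //   s0: arbitrate between a queued evict (training) and a queued lookup,
  //       evicts taking priority, and form the read index;
  //   s1: issue the SRAM read (stalled by s1_wait on a port/address hazard);
  //   s2: compare tags and compute the updated history counters;
  //   s3: write the entry back (evicts only) and generate prefetch requests.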
  val s3_ram_en = Wire(Bool())
  val s1_valid = Wire(Bool())
  // if s1.raddr == s2.waddr or s3 is using the ram port, block s1
  val s1_wait = Wire(Bool())
  // pipe s0: select an op from [lookup, update], generate ram read addr
  val s0_valid = lookup.valid || evict.valid

  evict.ready := !s1_valid || !s1_wait
  lookup.ready := evict.ready && !evict.valid

  val s0_ram_raddr = Mux(evict.valid,
    evict.bits.pht_index,
    lookup.bits.pht_index
  )
  val s0_tag = Mux(evict.valid, evict.bits.pht_tag, lookup.bits.pht_tag)
  val s0_region_offset = Mux(evict.valid, evict.bits.region_offset, lookup.bits.region_offset)
  val s0_region_paddr = lookup.bits.region_paddr
  val s0_region_vaddr = lookup.bits.region_vaddr
  val s0_region_bits = evict.bits.region_bits
  val s0_decr_mode = evict.bits.decr_mode
  val s0_evict = evict.valid
  val s0_access_cnt_signal = evict.bits.access_cnt
  val s0_single_update = evict.bits.single_update
  val s0_has_been_single_update = evict.bits.has_been_signal_updated
  val s0_region_bit_single = evict.bits.region_bit_single

  // pipe s1: send addr to ram
  val s1_valid_r = RegInit(false.B)
  s1_valid_r := Mux(s1_valid && s1_wait, true.B, s0_valid)
  s1_valid := s1_valid_r
  val s1_reg_en = s0_valid && (!s1_wait || !s1_valid)
  val s1_ram_raddr = RegEnable(s0_ram_raddr, s1_reg_en)
  val s1_tag = RegEnable(s0_tag, s1_reg_en)
  val s1_access_cnt_signal = RegEnable(s0_access_cnt_signal, s1_reg_en)
  val s1_region_bits = RegEnable(s0_region_bits, s1_reg_en)
  val s1_decr_mode = RegEnable(s0_decr_mode, s1_reg_en)
  val s1_region_paddr = RegEnable(s0_region_paddr, s1_reg_en)
  val s1_region_vaddr = RegEnable(s0_region_vaddr, s1_reg_en)
  val s1_region_offset = RegEnable(s0_region_offset, s1_reg_en)
  val s1_single_update = RegEnable(s0_single_update, s1_reg_en)
  val s1_has_been_single_update = RegEnable(s0_has_been_single_update, s1_reg_en)
  val s1_region_bit_single = RegEnable(s0_region_bit_single, s1_reg_en)
  val s1_pht_valids = pht_valids_reg.map(way => Mux1H(
    (0 until PHT_SETS).map(i => i.U === s1_ram_raddr),
    way
  ))
  val s1_evict = RegEnable(s0_evict, s1_reg_en)
  val s1_replace_way = Mux1H(
    (0 until PHT_SETS).map(i => i.U === s1_ram_raddr),
    replacement.map(_.way)
  )
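  // The 2 * (REGION_BLKS - 1) history counters are indexed by signed distance
  // from the trigger block: the upper REGION_BLKS - 1 slots cover the blocks
  // +1 .. +(REGION_BLKS - 1) ahead of the trigger, the lower slots the blocks
  // behind it. The shifts below rotate the region access bitmaps into this
  // trigger-relative frame; update_mask marks the slots that fall inside the
  // region for this trigger offset.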
  val s1_hist_update_mask = Cat(
    Fill(REGION_BLKS - 1, true.B), 0.U((REGION_BLKS - 1).W)
  ) >> s1_region_offset
  val s1_hist_bits = Cat(
    s1_region_bits.head(REGION_BLKS - 1) >> s1_region_offset,
    (Cat(
      s1_region_bits.tail(1), 0.U((REGION_BLKS - 1).W)
    ) >> s1_region_offset)(REGION_BLKS - 2, 0)
  )
  val s1_hist_single_bit = Cat(
    s1_region_bit_single.head(REGION_BLKS - 1) >> s1_region_offset,
    (Cat(
      s1_region_bit_single.tail(1), 0.U((REGION_BLKS - 1).W)
    ) >> s1_region_offset)(REGION_BLKS - 2, 0)
  )

  // pipe s2: generate ram write addr/data
  val s2_valid = GatedValidRegNext(s1_valid && !s1_wait, false.B)
  val s2_reg_en = s1_valid && !s1_wait
  val s2_hist_update_mask = RegEnable(s1_hist_update_mask, s2_reg_en)
  val s2_single_update = RegEnable(s1_single_update, s2_reg_en)
  val s2_has_been_single_update = RegEnable(s1_has_been_single_update, s2_reg_en)
  val s2_hist_bits = RegEnable(s1_hist_bits, s2_reg_en)
  val s2_hist_bit_single = RegEnable(s1_hist_single_bit, s2_reg_en)
  val s2_tag = RegEnable(s1_tag, s2_reg_en)
  val s2_region_bits = RegEnable(s1_region_bits, s2_reg_en)
  val s2_decr_mode = RegEnable(s1_decr_mode, s2_reg_en)
  val s2_region_paddr = RegEnable(s1_region_paddr, s2_reg_en)
  val s2_region_vaddr = RegEnable(s1_region_vaddr, s2_reg_en)
  val s2_region_offset = RegEnable(s1_region_offset, s2_reg_en)
  val s2_region_offset_mask = region_offset_to_bits(s2_region_offset)
  val s2_evict = RegEnable(s1_evict, s2_reg_en)
  val s2_pht_valids = s1_pht_valids.map(v => RegEnable(v, s2_reg_en))
  val s2_replace_way = RegEnable(s1_replace_way, s2_reg_en)
  val s2_ram_waddr = RegEnable(s1_ram_raddr, s2_reg_en)
  val s2_ram_rdata = pht_ram.io.r.resp.data
  val s2_ram_rtags = s2_ram_rdata.map(_.tag)
  val s2_tag_match_vec = s2_ram_rtags.map(t => t === s2_tag)
  val s2_access_cnt_signal = RegEnable(s1_access_cnt_signal, s2_reg_en)
  val s2_hit_vec = s2_tag_match_vec.zip(s2_pht_valids).map({
    case (tag_match, v) => v && tag_match
  })

  // distinguish single updates from evict updates
  val s2_hist_update = s2_ram_rdata.map(way => VecInit(way.hist.zipWithIndex.map({
    case (h, i) =>
      val do_update = s2_hist_update_mask(i)
      val hist_updated = Mux(!s2_single_update,
        // evict update: saturating increment for accessed blocks, decrement for
        // untouched ones; if single updates already bumped this entry, skip the
        // extra increment and only decay the untouched blocks
        Mux(s2_has_been_single_update,
          Mux(s2_hist_bits(i), h, Mux(h === 0.U, 0.U, h - 1.U)),
          Mux(s2_hist_bits(i), Mux(h.andR, h, h + 1.U), Mux(h === 0.U, 0.U, h - 1.U))
        ),
        // single update: only bump the counter of the newly accessed block
        // (a counter at 0 jumps straight to 2)
        Mux(s2_hist_bit_single(i), Mux(h.andR, h, Mux(h === 0.U, h + 2.U, h + 1.U)), h)
      )
      Mux(do_update, hist_updated, h)
  })))

  val s2_hist_pf_gen = Mux1H(s2_hit_vec, s2_ram_rdata.map(way => VecInit(way.hist.map(_.head(1))).asUInt))
  val s2_new_hist = VecInit(s2_hist_bits.asBools.map(b => Cat(0.U((PHT_HIST_BITS - 1).W), b)))
  val s2_new_hist_single = VecInit(s2_hist_bit_single.asBools.map(b => Cat(0.U((PHT_HIST_BITS - 1).W), b)))
  val s2_new_hist_real = Mux(s2_single_update, s2_new_hist_single, s2_new_hist)
  val s2_pht_hit = Cat(s2_hit_vec).orR
  // a single update only writes on a pht hit, or once more than 4 blocks of the region were accessed
  val signal_update_write = Mux(!s2_single_update, true.B, s2_pht_hit || s2_single_update && (s2_access_cnt_signal > 4.U))
  val s2_hist = Mux(s2_pht_hit, Mux1H(s2_hit_vec, s2_hist_update), s2_new_hist_real)
  val s2_repl_way_mask = UIntToOH(s2_replace_way)
  val s2_incr_region_vaddr = s2_region_vaddr + 1.U
  val s2_decr_region_vaddr = s2_region_vaddr - 1.U

  // pipe s3: send addr/data to ram, gen pf_req
  val s3_valid = GatedValidRegNext(s2_valid && signal_update_write, false.B)
  val s3_evict = RegEnable(s2_evict, s2_valid)
  val s3_hist = RegEnable(s2_hist, s2_valid)
  val s3_hist_pf_gen = RegEnable(s2_hist_pf_gen, s2_valid)

  val s3_hist_update_mask = RegEnable(s2_hist_update_mask.asUInt, s2_valid)

  val s3_region_offset = RegEnable(s2_region_offset, s2_valid)
  val s3_region_offset_mask = RegEnable(s2_region_offset_mask, s2_valid)
  val s3_decr_mode = RegEnable(s2_decr_mode, s2_valid)
  val s3_region_paddr = RegEnable(s2_region_paddr, s2_valid)
  val s3_region_vaddr = RegEnable(s2_region_vaddr, s2_valid)
  val s3_pht_tag = RegEnable(s2_tag, s2_valid)
  val s3_hit_vec = s2_hit_vec.map(h => RegEnable(h, s2_valid))
  val s3_hit = Cat(s3_hit_vec).orR
  val s3_hit_way = OHToUInt(s3_hit_vec)
  val s3_repl_way = RegEnable(s2_replace_way, s2_valid)
  val s3_repl_way_mask = RegEnable(s2_repl_way_mask, s2_valid)
  val s3_repl_update_mask = RegEnable(VecInit((0 until PHT_SETS).map(i => i.U === s2_ram_waddr)), s2_valid)
  val s3_ram_waddr = RegEnable(s2_ram_waddr, s2_valid)
  val s3_incr_region_vaddr = RegEnable(s2_incr_region_vaddr, s2_valid)
  val s3_decr_region_vaddr = RegEnable(s2_decr_region_vaddr, s2_valid)
  s3_ram_en := s3_valid && s3_evict
  val s3_ram_wdata = Wire(new PhtEntry())
  s3_ram_wdata.hist := s3_hist
  s3_ram_wdata.tag := s3_pht_tag
  s3_ram_wdata.decr_mode := s3_decr_mode

  s1_wait := (s2_valid && s2_evict && s2_ram_waddr === s1_ram_raddr) || s3_ram_en

  for((valids, way_idx) <- pht_valids_next.zipWithIndex){
    val update_way = s3_repl_way_mask(way_idx)
    for((v, set_idx) <- valids.zipWithIndex){
      val update_set = s3_repl_update_mask(set_idx)
      when(s3_valid && s3_evict && !s3_hit && update_set && update_way){
        pht_valids_enable(set_idx) := true.B
        v := true.B
      }
    }
  }
  for((r, i) <- replacement.zipWithIndex){
    when(s3_valid && s3_repl_update_mask(i)){
      when(s3_hit){
        r.access(s3_hit_way)
      }.elsewhen(s3_evict){
        r.access(s3_repl_way)
      }
    }
  }

  val s3_way_mask = Mux(s3_hit,
    VecInit(s3_hit_vec).asUInt,
    s3_repl_way_mask
  ).asUInt

  pht_ram.io.r(
    s1_valid, s1_ram_raddr
  )
  pht_ram.io.w(
    s3_ram_en, s3_ram_wdata, s3_ram_waddr, s3_way_mask
  )
  when(s3_valid && s3_hit){
    assert(!Cat(s3_hit_vec).andR, "sms_pht: multi-hit!")
  }

  // generate pf req if hit
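  // Re-align the trigger-relative history back to region frames: the high half
  // (hist_hi, forward distances) and the low half (hist_lo, backward
  // distances) are shifted by the trigger offset; bits landing inside the
  // trigger's own region form cur_region_bits, while overflow into the
  // neighboring regions forms incr/decr_region_bits for the +1/-1 region
  // requests.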
  val s3_hist_hi = s3_hist_pf_gen.head(REGION_BLKS - 1)
  val s3_hist_lo = s3_hist_pf_gen.tail(REGION_BLKS - 1)
  val s3_hist_hi_shifted = (Cat(0.U((REGION_BLKS - 1).W), s3_hist_hi) << s3_region_offset)(2 * (REGION_BLKS - 1) - 1, 0)
  val s3_hist_lo_shifted = (Cat(0.U((REGION_BLKS - 1).W), s3_hist_lo) << s3_region_offset)(2 * (REGION_BLKS - 1) - 1, 0)
  val s3_cur_region_bits = Cat(s3_hist_hi_shifted.tail(REGION_BLKS - 1), 0.U(1.W)) |
    Cat(0.U(1.W), s3_hist_lo_shifted.head(REGION_BLKS - 1))
  val s3_incr_region_bits = Cat(0.U(1.W), s3_hist_hi_shifted.head(REGION_BLKS - 1))
  val s3_decr_region_bits = Cat(s3_hist_lo_shifted.tail(REGION_BLKS - 1), 0.U(1.W))
  val s3_pf_gen_valid = s3_valid && s3_hit && !s3_evict
  val s3_cur_region_valid = s3_pf_gen_valid && (s3_hist_pf_gen & s3_hist_update_mask).orR
  val s3_incr_region_valid = s3_pf_gen_valid && (s3_hist_hi & (~s3_hist_update_mask.head(REGION_BLKS - 1)).asUInt).orR
  val s3_decr_region_valid = s3_pf_gen_valid && (s3_hist_lo & (~s3_hist_update_mask.tail(REGION_BLKS - 1)).asUInt).orR
  val s3_incr_alias_bits = get_alias_bits(s3_incr_region_vaddr)
  val s3_decr_alias_bits = get_alias_bits(s3_decr_region_vaddr)
  val s3_incr_region_paddr = Cat(
    s3_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT),
    s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)
  )
  val s3_decr_region_paddr = Cat(
    s3_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT),
    s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)
  )
  val s3_incr_crosspage = s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s3_decr_crosspage = s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s3_cur_region_tag = region_hash_tag(s3_region_vaddr)
  val s3_incr_region_tag = region_hash_tag(s3_incr_region_vaddr)
  val s3_decr_region_tag = region_hash_tag(s3_decr_region_vaddr)

  val pf_gen_req_arb = Module(new Arbiter(new PfGenReq, 3))
  val s4_pf_gen_cur_region_valid = RegInit(false.B)
  val s4_pf_gen_cur_region = Reg(new PfGenReq)
  val s4_pf_gen_incr_region_valid = RegInit(false.B)
  val s4_pf_gen_incr_region = Reg(new PfGenReq)
  val s4_pf_gen_decr_region_valid = RegInit(false.B)
  val s4_pf_gen_decr_region = Reg(new PfGenReq)

  s4_pf_gen_cur_region_valid := s3_cur_region_valid
  when(s3_cur_region_valid){
    s4_pf_gen_cur_region.region_addr := s3_region_paddr
    s4_pf_gen_cur_region.alias_bits := get_alias_bits(s3_region_vaddr)
    s4_pf_gen_cur_region.region_tag := s3_cur_region_tag
    s4_pf_gen_cur_region.region_bits := s3_cur_region_bits
    s4_pf_gen_cur_region.paddr_valid := true.B
    s4_pf_gen_cur_region.decr_mode := false.B
  }
  s4_pf_gen_incr_region_valid := s3_incr_region_valid ||
    (!pf_gen_req_arb.io.in(1).ready && s4_pf_gen_incr_region_valid)
  when(s3_incr_region_valid){
    s4_pf_gen_incr_region.region_addr := Mux(s3_incr_crosspage, s3_incr_region_vaddr, s3_incr_region_paddr)
    s4_pf_gen_incr_region.alias_bits := s3_incr_alias_bits
    s4_pf_gen_incr_region.region_tag := s3_incr_region_tag
    s4_pf_gen_incr_region.region_bits := s3_incr_region_bits
    s4_pf_gen_incr_region.paddr_valid := !s3_incr_crosspage
    s4_pf_gen_incr_region.decr_mode := false.B
  }
  s4_pf_gen_decr_region_valid := s3_decr_region_valid ||
    (!pf_gen_req_arb.io.in(2).ready && s4_pf_gen_decr_region_valid)
  when(s3_decr_region_valid){
    s4_pf_gen_decr_region.region_addr := Mux(s3_decr_crosspage, s3_decr_region_vaddr, s3_decr_region_paddr)
    s4_pf_gen_decr_region.alias_bits := s3_decr_alias_bits
    s4_pf_gen_decr_region.region_tag := s3_decr_region_tag
    s4_pf_gen_decr_region.region_bits := s3_decr_region_bits
    s4_pf_gen_decr_region.paddr_valid := !s3_decr_crosspage
    s4_pf_gen_decr_region.decr_mode := true.B
  }

  pf_gen_req_arb.io.in.head.valid := s4_pf_gen_cur_region_valid
  pf_gen_req_arb.io.in.head.bits := s4_pf_gen_cur_region
  pf_gen_req_arb.io.in.head.bits.debug_source_type := HW_PREFETCH_PHT_CUR.U
  pf_gen_req_arb.io.in(1).valid := s4_pf_gen_incr_region_valid
  pf_gen_req_arb.io.in(1).bits := s4_pf_gen_incr_region
  pf_gen_req_arb.io.in(1).bits.debug_source_type := HW_PREFETCH_PHT_INC.U
  pf_gen_req_arb.io.in(2).valid := s4_pf_gen_decr_region_valid
  pf_gen_req_arb.io.in(2).bits := s4_pf_gen_decr_region
  pf_gen_req_arb.io.in(2).bits.debug_source_type := HW_PREFETCH_PHT_DEC.U
  pf_gen_req_arb.io.out.ready := true.B

  io.pf_gen_req.valid := pf_gen_req_arb.io.out.valid
  io.pf_gen_req.bits := pf_gen_req_arb.io.out.bits

  XSPerfAccumulate("sms_pht_update", io.agt_update.valid)
  XSPerfAccumulate("sms_pht_update_hit", s2_valid && s2_evict && s2_pht_hit)
  XSPerfAccumulate("sms_pht_lookup", io.s2_agt_lookup.valid)
  XSPerfAccumulate("sms_pht_lookup_hit", s2_valid && !s2_evict && s2_pht_hit)
  for(i <- 0 until smsParams.pht_ways){
    XSPerfAccumulate(s"sms_pht_write_way_$i", pht_ram.io.w.req.fire && pht_ram.io.w.req.bits.waymask.get(i))
  }
  for(i <- 0 until PHT_SETS){
    XSPerfAccumulate(s"sms_pht_write_set_$i", pht_ram.io.w.req.fire && pht_ram.io.w.req.bits.setIdx === i.U)
  }
  XSPerfAccumulate(s"sms_pht_pf_gen", io.pf_gen_req.valid)
}

class PrefetchFilterEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val region_tag = UInt(REGION_TAG_WIDTH.W)
  val region_addr = UInt(REGION_ADDR_BITS.W)
  val region_bits = UInt(REGION_BLKS.W)
  val filter_bits = UInt(REGION_BLKS.W)
  val alias_bits = UInt(2.W)
  val paddr_valid = Bool()
  val decr_mode = Bool()
  val debug_source_type = UInt(log2Up(nSourceType).W)
}

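// PrefetchFilter: buffers region-granular prefetch requests, merges requests
// to the same region, translates virtual regions via the prefetcher TLB port
// (dropping faulting or uncacheable targets), and issues at most one L2
// prefetch block address per cycle while filter_bits records what was already
// sent.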
class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val gen_req = Flipped(ValidIO(new PfGenReq()))
    val tlb_req = new TlbRequestIO(2)
    val pmp_resp = Flipped(new PMPRespBundle())
    val l2_pf_addr = ValidIO(UInt(PAddrBits.W))
    val pf_alias_bits = Output(UInt(2.W))
    val debug_source_type = Output(UInt(log2Up(nSourceType).W))
  })
  val entries = Seq.fill(smsParams.pf_filter_size){ Reg(new PrefetchFilterEntry()) }
  val valids = Seq.fill(smsParams.pf_filter_size){ RegInit(false.B) }
  val replacement = ReplacementPolicy.fromString("plru", smsParams.pf_filter_size)

  val prev_valid = GatedValidRegNext(io.gen_req.valid, false.B)
  val prev_gen_req = RegEnable(io.gen_req.bits, io.gen_req.valid)

  val tlb_req_arb = Module(new RRArbiterInit(new TlbReq, smsParams.pf_filter_size))
  val pf_req_arb = Module(new RRArbiterInit(UInt(PAddrBits.W), smsParams.pf_filter_size))

  io.l2_pf_addr.valid := pf_req_arb.io.out.valid
  io.l2_pf_addr.bits := pf_req_arb.io.out.bits
  io.pf_alias_bits := Mux1H(entries.zipWithIndex.map({
    case (entry, i) => (i.U === pf_req_arb.io.chosen) -> entry.alias_bits
  }))
  pf_req_arb.io.out.ready := true.B

  io.debug_source_type := VecInit(entries.map(_.debug_source_type))(pf_req_arb.io.chosen)

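  // s1/s2/s3_tlb_fire_vec track which entry's translation request occupies
  // each downstream stage (tlb request, tlb response, pmp response) so the
  // result can be written back to the right entry and entries with an
  // in-flight translation do not re-request.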
  val s1_valid = Wire(Bool())
  val s1_hit = Wire(Bool())
  val s1_replace_vec = Wire(UInt(smsParams.pf_filter_size.W))
  val s1_tlb_fire_vec = Wire(UInt(smsParams.pf_filter_size.W))
  val s2_tlb_fire_vec = Wire(UInt(smsParams.pf_filter_size.W))
  val s3_tlb_fire_vec = Wire(UInt(smsParams.pf_filter_size.W))
  val not_tlbing_vec = VecInit((0 until smsParams.pf_filter_size).map{ i =>
    !s1_tlb_fire_vec(i) && !s2_tlb_fire_vec(i) && !s3_tlb_fire_vec(i)
  })

  // s0: entries lookup
  val s0_gen_req = io.gen_req.bits
  val s0_match_prev = prev_valid && (s0_gen_req.region_tag === prev_gen_req.region_tag)
  val s0_gen_req_valid = io.gen_req.valid && !s0_match_prev
  val s0_match_vec = valids.indices.map(i => {
    valids(i) && entries(i).region_tag === s0_gen_req.region_tag && !(s1_valid && !s1_hit && s1_replace_vec(i))
  })
  val s0_any_matched = Cat(s0_match_vec).orR
  val s0_replace_vec = UIntToOH(replacement.way)
  val s0_hit = s0_gen_req_valid && s0_any_matched

  for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
    val is_evicted = s1_valid && s1_replace_vec(i)
    tlb_req_arb.io.in(i).valid := v && not_tlbing_vec(i) && !ent.paddr_valid && !is_evicted
    tlb_req_arb.io.in(i).bits.vaddr := Cat(ent.region_addr, 0.U(log2Up(REGION_SIZE).W))
    tlb_req_arb.io.in(i).bits.cmd := TlbCmd.read
    tlb_req_arb.io.in(i).bits.isPrefetch := true.B
    tlb_req_arb.io.in(i).bits.size := 3.U
    tlb_req_arb.io.in(i).bits.kill := false.B
    tlb_req_arb.io.in(i).bits.no_translate := false.B
    tlb_req_arb.io.in(i).bits.fullva := 0.U
    tlb_req_arb.io.in(i).bits.checkfullva := false.B
    tlb_req_arb.io.in(i).bits.memidx := DontCare
    tlb_req_arb.io.in(i).bits.debug := DontCare
    tlb_req_arb.io.in(i).bits.hlvx := DontCare
    tlb_req_arb.io.in(i).bits.hyperinst := DontCare
    tlb_req_arb.io.in(i).bits.pmp_addr := DontCare

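    // blocks requested by the generation but not yet issued to L2; issue them
    // lowest-offset-first for an increasing stream, highest-first for a
    // decreasing one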
    val pending_req_vec = ent.region_bits & (~ent.filter_bits).asUInt
    val first_one_offset = PriorityMux(
      pending_req_vec.asBools,
      (0 until REGION_BLKS).map(_.U(REGION_OFFSET.W))
    )
    val last_one_offset = PriorityMux(
      pending_req_vec.asBools.reverse,
      (0 until REGION_BLKS).reverse.map(_.U(REGION_OFFSET.W))
    )
    val pf_addr = Cat(
      ent.region_addr,
      Mux(ent.decr_mode, last_one_offset, first_one_offset),
      0.U(log2Up(dcacheParameters.blockBytes).W)
    )
    pf_req_arb.io.in(i).valid := v && Cat(pending_req_vec).orR && ent.paddr_valid && !is_evicted
    pf_req_arb.io.in(i).bits := pf_addr
  }

  val s0_tlb_fire_vec = VecInit(tlb_req_arb.io.in.map(_.fire))
  val s0_pf_fire_vec = VecInit(pf_req_arb.io.in.map(_.fire))

  val s0_update_way = OHToUInt(s0_match_vec)
  val s0_replace_way = replacement.way
  val s0_access_way = Mux(s0_any_matched, s0_update_way, s0_replace_way)
  when(s0_gen_req_valid){
    replacement.access(s0_access_way)
  }

  // s1: update or alloc
  val s1_valid_r = GatedValidRegNext(s0_gen_req_valid, false.B)
  val s1_hit_r = RegEnable(s0_hit, false.B, s0_gen_req_valid)
  val s1_gen_req = RegEnable(s0_gen_req, s0_gen_req_valid)
  val s1_replace_vec_r = RegEnable(s0_replace_vec, s0_gen_req_valid && !s0_hit)
  val s1_update_vec = RegEnable(VecInit(s0_match_vec).asUInt, s0_gen_req_valid && s0_hit)
  val s1_tlb_fire_vec_r = GatedValidRegNext(s0_tlb_fire_vec)
  // the tlb req is latched one cycle after tlb_req_arb
  val s1_tlb_req_valid = GatedValidRegNext(tlb_req_arb.io.out.fire)
  val s1_tlb_req_bits = RegEnable(tlb_req_arb.io.out.bits, tlb_req_arb.io.out.fire)
  val s1_alloc_entry = Wire(new PrefetchFilterEntry())
  s1_valid := s1_valid_r
  s1_hit := s1_hit_r
  s1_replace_vec := s1_replace_vec_r
  s1_tlb_fire_vec := s1_tlb_fire_vec_r.asUInt
  s1_alloc_entry.region_tag := s1_gen_req.region_tag
  s1_alloc_entry.region_addr := s1_gen_req.region_addr
  s1_alloc_entry.region_bits := s1_gen_req.region_bits
  s1_alloc_entry.paddr_valid := s1_gen_req.paddr_valid
  s1_alloc_entry.decr_mode := s1_gen_req.decr_mode
  s1_alloc_entry.filter_bits := 0.U
  s1_alloc_entry.alias_bits := s1_gen_req.alias_bits
  s1_alloc_entry.debug_source_type := s1_gen_req.debug_source_type
  io.tlb_req.req.valid := s1_tlb_req_valid && !((s1_tlb_fire_vec & s1_replace_vec).orR && s1_valid && !s1_hit)
  io.tlb_req.req.bits := s1_tlb_req_bits
  io.tlb_req.resp.ready := true.B
  io.tlb_req.req_kill := false.B
  tlb_req_arb.io.out.ready := true.B

  // s2: get response from tlb
  val s2_tlb_fire_vec_r = GatedValidRegNext(s1_tlb_fire_vec_r)
  s2_tlb_fire_vec := s2_tlb_fire_vec_r.asUInt

  // s3: get pmp response from PMPChecker
  val s3_tlb_fire_vec_r = GatedValidRegNext(s2_tlb_fire_vec_r)
  val s3_tlb_resp_fire = RegNext(io.tlb_req.resp.fire)
  val s3_tlb_resp = RegEnable(io.tlb_req.resp.bits, io.tlb_req.resp.valid)
  val s3_pmp_resp = io.pmp_resp
  val s3_update_valid = s3_tlb_resp_fire && !s3_tlb_resp.miss
  val s3_drop = s3_update_valid && (
    // page/access fault
    s3_tlb_resp.excp.head.pf.ld || s3_tlb_resp.excp.head.gpf.ld || s3_tlb_resp.excp.head.af.ld ||
    // uncache
    s3_pmp_resp.mmio || Pbmt.isUncache(s3_tlb_resp.pbmt.head) ||
    // pmp access fault
    s3_pmp_resp.ld
  )
  s3_tlb_fire_vec := s3_tlb_fire_vec_r.asUInt

  for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
    val alloc = s1_valid && !s1_hit && s1_replace_vec(i)
    val update = s1_valid && s1_hit && s1_update_vec(i)
    // for pf: use s0 data
    val pf_fired = s0_pf_fire_vec(i)
    val tlb_fired = s3_tlb_fire_vec(i) && s3_update_valid
    when(tlb_fired){
      when(s3_drop){
        v := false.B
      }.otherwise{
        ent.paddr_valid := !s3_tlb_resp.miss
        ent.region_addr := region_addr(s3_tlb_resp.paddr.head)
      }
    }
    when(update){
      ent.region_bits := ent.region_bits | s1_gen_req.region_bits
    }
    when(pf_fired){
      val curr_bit = UIntToOH(block_addr(pf_req_arb.io.in(i).bits)(REGION_OFFSET - 1, 0))
      ent.filter_bits := ent.filter_bits | curr_bit
    }
    when(alloc){
      ent := s1_alloc_entry
      v := true.B
    }
  }
  when(s1_valid && s1_hit){
    assert(PopCount(s1_update_vec) === 1.U, "sms_pf_filter: multi-hit")
  }
  assert(!io.tlb_req.resp.fire || Cat(s2_tlb_fire_vec).orR, "sms_pf_filter: tlb resp fires, but no tlb req from tlb_req_arb 2 cycles ago")

  XSPerfAccumulate("sms_pf_filter_recv_req", io.gen_req.valid)
  XSPerfAccumulate("sms_pf_filter_hit", s1_valid && s1_hit)
  XSPerfAccumulate("sms_pf_filter_tlb_req", io.tlb_req.req.fire)
  XSPerfAccumulate("sms_pf_filter_tlb_resp_miss", io.tlb_req.resp.fire && io.tlb_req.resp.bits.miss)
  XSPerfAccumulate("sms_pf_filter_tlb_resp_drop", s3_drop)
  XSPerfAccumulate("sms_pf_filter_tlb_resp_drop_by_pf_or_af",
    s3_update_valid && (s3_tlb_resp.excp.head.pf.ld || s3_tlb_resp.excp.head.gpf.ld || s3_tlb_resp.excp.head.af.ld)
  )
  XSPerfAccumulate("sms_pf_filter_tlb_resp_drop_by_uncache",
    s3_update_valid && (s3_pmp_resp.mmio || Pbmt.isUncache(s3_tlb_resp.pbmt.head))
  )
  XSPerfAccumulate("sms_pf_filter_tlb_resp_drop_by_pmp_af",
    s3_update_valid && s3_pmp_resp.ld
  )
  for(i <- 0 until smsParams.pf_filter_size){
    XSPerfAccumulate(s"sms_pf_filter_access_way_$i", s0_gen_req_valid && s0_access_way === i.U)
  }
  XSPerfAccumulate("sms_pf_filter_l2_req", io.l2_pf_addr.valid)
}

class SMSTrainFilter()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper with HasTrainFilterHelper {
  val io = IO(new Bundle() {
    // train input
    // hybrid load/store units
    val ld_in = Flipped(Vec(backendParams.LdExuCnt, ValidIO(new LsPrefetchTrainBundle())))
    val st_in = Flipped(Vec(backendParams.StaExuCnt, ValidIO(new LsPrefetchTrainBundle())))
    // filter out
    val train_req = ValidIO(new PrefetchReqBundle())
  })

  class Ptr(implicit p: Parameters) extends CircularQueuePtr[Ptr](
    p => smsParams.train_filter_size
  ){
  }

  object Ptr {
    def apply(f: Bool, v: UInt)(implicit p: Parameters): Ptr = {
      val ptr = Wire(new Ptr)
      ptr.flag := f
      ptr.value := v
      ptr
    }
  }

  val entries = RegInit(VecInit(Seq.fill(smsParams.train_filter_size){ 0.U.asTypeOf(new PrefetchReqBundle()) }))
  val valids = RegInit(VecInit(Seq.fill(smsParams.train_filter_size){ false.B }))

  val enqLen = backendParams.LduCnt + backendParams.StaCnt
  val enqPtrExt = RegInit(VecInit((0 until enqLen).map(_.U.asTypeOf(new Ptr))))
  val deqPtrExt = RegInit(0.U.asTypeOf(new Ptr))

  val deqPtr = WireInit(deqPtrExt.value)

  require(smsParams.train_filter_size >= enqLen)

  val ld_reorder = reorder(io.ld_in)
  val st_reorder = reorder(io.st_in)
  val reqs_ls = ld_reorder.map(_.bits.toPrefetchReqBundle()) ++ st_reorder.map(_.bits.toPrefetchReqBundle())
  val reqs_vls = ld_reorder.map(_.valid) ++ st_reorder.map(_.valid)
  val needAlloc = Wire(Vec(enqLen, Bool()))
  val canAlloc = Wire(Vec(enqLen, Bool()))

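  // Enqueue with dedup: a request is dropped if an existing entry or an
  // earlier same-cycle request already covers the same (hashed) block address.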
1158  for(i <- (0 until enqLen)) {
1159    val req = reqs_ls(i)
1160    val req_v = reqs_vls(i)
1161    val index = PopCount(needAlloc.take(i))
1162    val allocPtr = enqPtrExt(index)
1163    val entry_match = Cat(entries.zip(valids).map {
1164      case(e, v) => v && block_hash_tag(e.vaddr) === block_hash_tag(req.vaddr)
1165    }).orR
1166    val prev_enq_match = if(i == 0) false.B else Cat(reqs_ls.zip(reqs_vls).take(i).map {
1167      case(pre, pre_v) => pre_v && block_hash_tag(pre.vaddr) === block_hash_tag(req.vaddr)
1168    }).orR
1169
1170    needAlloc(i) := req_v && !entry_match && !prev_enq_match
1171    canAlloc(i) := needAlloc(i) && allocPtr >= deqPtrExt
1172
1173    when(canAlloc(i)) {
1174      valids(allocPtr.value) := true.B
1175      entries(allocPtr.value) := req
1176    }
1177  }
1178  val allocNum = PopCount(canAlloc)
1179
1180  enqPtrExt.foreach{case x => when(canAlloc.asUInt.orR) {x := x + allocNum} }
1181
1182  io.train_req.valid := false.B
1183  io.train_req.bits := DontCare
1184  valids.zip(entries).zipWithIndex.foreach {
1185    case((valid, entry), i) => {
1186      when(deqPtr === i.U) {
1187        io.train_req.valid := valid
1188        io.train_req.bits := entry
1189      }
1190    }
1191  }
1192
1193  when(io.train_req.valid) {
1194    valids(deqPtr) := false.B
1195    deqPtrExt := deqPtrExt + 1.U
1196  }
1197
1198  XSPerfAccumulate("sms_train_filter_full", PopCount(valids) === (smsParams.train_filter_size).U)
1199  XSPerfAccumulate("sms_train_filter_half", PopCount(valids) >= (smsParams.train_filter_size / 2).U)
1200  XSPerfAccumulate("sms_train_filter_empty", PopCount(valids) === 0.U)
1201
1202  val raw_enq_pattern = Cat(reqs_vls)
1203  val filtered_enq_pattern = Cat(needAlloc)
1204  val actual_enq_pattern = Cat(canAlloc)
1205  XSPerfAccumulate("sms_train_filter_enq", allocNum > 0.U)
1206  XSPerfAccumulate("sms_train_filter_deq", io.train_req.fire)
1207  def toBinary(n: Int): String = n match {
1208    case 0|1 => s"$n"
1209    case _   => s"${toBinary(n/2)}${n%2}"
1210  }
1211  for(i <- 0 until (1 << enqLen)) {
1212    XSPerfAccumulate(s"sms_train_filter_raw_enq_pattern_${toBinary(i)}", raw_enq_pattern === i.U)
1213    XSPerfAccumulate(s"sms_train_filter_filtered_enq_pattern_${toBinary(i)}", filtered_enq_pattern === i.U)
1214    XSPerfAccumulate(s"sms_train_filter_actual_enq_pattern_${toBinary(i)}", actual_enq_pattern === i.U)
1215  }
1216}
1217
1218class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSModuleHelper with HasL1PrefetchSourceParameter {
1219  import freechips.rocketchip.util._
1220
1221  val io_agt_en = IO(Input(Bool()))
1222  val io_stride_en = IO(Input(Bool()))
1223  val io_pht_en = IO(Input(Bool()))
1224  val io_act_threshold = IO(Input(UInt(REGION_OFFSET.W)))
1225  val io_act_stride = IO(Input(UInt(6.W)))
1226  val io_dcache_evict = IO(Flipped(DecoupledIO(new AGTEvictReq)))
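  // runtime tuning inputs: per-engine enables (AGT / stride / PHT), the AGT
  // activation threshold and stride, and a dcache eviction channel that is
  // forwarded to the active generation table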
1227
1228  val train_filter = Module(new SMSTrainFilter)
1229
1230  train_filter.io.ld_in <> io.ld_in
1231  train_filter.io.st_in <> io.st_in
1232
1233  val train_ld = train_filter.io.train_req.bits
1234
1235  val train_block_tag = block_hash_tag(train_ld.vaddr)
1236  val train_region_tag = train_block_tag.head(REGION_TAG_WIDTH)
1237
1238  val train_region_addr_raw = region_addr(train_ld.vaddr)(REGION_TAG_WIDTH + 2 * VADDR_HASH_WIDTH - 1, 0)
1239  val train_region_addr_p1 = Cat(0.U(1.W), train_region_addr_raw) + 1.U
1240  val train_region_addr_m1 = Cat(0.U(1.W), train_region_addr_raw) - 1.U
1241  // are addr_p1 / addr_m1 still valid region addresses (no carry/borrow out of the raw width)?
1242  val train_allow_cross_region_p1 = !train_region_addr_p1.head(1).asBool
1243  val train_allow_cross_region_m1 = !train_region_addr_m1.head(1).asBool
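  // the extra MSB from Cat captures carry/borrow: e.g. when the raw region
  // address is all ones, +1 overflows, the head bit becomes 1, and the +1
  // neighbor is treated as invalid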
1244
1245  val train_region_p1_tag = region_hash_tag(train_region_addr_p1.tail(1))
1246  val train_region_m1_tag = region_hash_tag(train_region_addr_m1.tail(1))
1247
1248  val train_region_p1_cross_page = page_bit(train_region_addr_p1) ^ page_bit(train_region_addr_raw)
1249  val train_region_m1_cross_page = page_bit(train_region_addr_m1) ^ page_bit(train_region_addr_raw)
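  // a neighboring region lies in a different page exactly when its page bit
  // differs from that of the original region address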
1250
1251  val train_region_paddr = region_addr(train_ld.paddr)
1252  val train_region_vaddr = region_addr(train_ld.vaddr)
1253  val train_region_offset = train_block_tag(REGION_OFFSET - 1, 0)
1254  val train_vld = train_filter.io.train_req.valid
1255
1256
1257  // prefetch stage0
1258  val active_gen_table = Module(new ActiveGenerationTable())
1259  val stride = Module(new StridePF())
1260  val pht = Module(new PatternHistoryTable())
1261  val pf_filter = Module(new PrefetchFilter())
1262
1263  val train_vld_s0 = GatedValidRegNext(train_vld, false.B)
1264  val train_s0 = RegEnable(train_ld, train_vld)
1265  val train_region_tag_s0 = RegEnable(train_region_tag, train_vld)
1266  val train_region_p1_tag_s0 = RegEnable(train_region_p1_tag, train_vld)
1267  val train_region_m1_tag_s0 = RegEnable(train_region_m1_tag, train_vld)
1268  val train_allow_cross_region_p1_s0 = RegEnable(train_allow_cross_region_p1, train_vld)
1269  val train_allow_cross_region_m1_s0 = RegEnable(train_allow_cross_region_m1, train_vld)
1270  val train_pht_tag_s0 = RegEnable(pht_tag(train_ld.pc), train_vld)
1271  val train_pht_index_s0 = RegEnable(pht_index(train_ld.pc), train_vld)
1272  val train_region_offset_s0 = RegEnable(train_region_offset, train_vld)
1273  val train_region_p1_cross_page_s0 = RegEnable(train_region_p1_cross_page, train_vld)
1274  val train_region_m1_cross_page_s0 = RegEnable(train_region_m1_cross_page, train_vld)
1275  val train_region_paddr_s0 = RegEnable(train_region_paddr, train_vld)
1276  val train_region_vaddr_s0 = RegEnable(train_region_vaddr, train_vld)
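  // stage 0: everything derived from the train request is latched for one
  // cycle before the table lookups below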
1277
1278  active_gen_table.io.agt_en := io_agt_en
1279  active_gen_table.io.act_threshold := io_act_threshold
1280  active_gen_table.io.act_stride := io_act_stride
1281  active_gen_table.io.s0_lookup.valid := train_vld_s0
1282  active_gen_table.io.s0_lookup.bits.region_tag := train_region_tag_s0
1283  active_gen_table.io.s0_lookup.bits.region_p1_tag := train_region_p1_tag_s0
1284  active_gen_table.io.s0_lookup.bits.region_m1_tag := train_region_m1_tag_s0
1285  active_gen_table.io.s0_lookup.bits.region_offset := train_region_offset_s0
1286  active_gen_table.io.s0_lookup.bits.pht_index := train_pht_index_s0
1287  active_gen_table.io.s0_lookup.bits.pht_tag := train_pht_tag_s0
1288  active_gen_table.io.s0_lookup.bits.allow_cross_region_p1 := train_allow_cross_region_p1_s0
1289  active_gen_table.io.s0_lookup.bits.allow_cross_region_m1 := train_allow_cross_region_m1_s0
1290  active_gen_table.io.s0_lookup.bits.region_p1_cross_page := train_region_p1_cross_page_s0
1291  active_gen_table.io.s0_lookup.bits.region_m1_cross_page := train_region_m1_cross_page_s0
1292  active_gen_table.io.s0_lookup.bits.region_paddr := train_region_paddr_s0
1293  active_gen_table.io.s0_lookup.bits.region_vaddr := train_region_vaddr_s0
1294  active_gen_table.io.s2_stride_hit := stride.io.s2_gen_req.valid
1295  active_gen_table.io.s0_dcache_evict <> io_dcache_evict
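  // the s0 lookup probes the trained region and both of its neighbors;
  // s2_stride_hit tells the AGT that the stride engine produced a request for
  // this train event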
1296
1297  stride.io.stride_en := io_stride_en
1298  stride.io.s0_lookup.valid := train_vld_s0
1299  stride.io.s0_lookup.bits.pc := train_s0.pc(STRIDE_PC_BITS - 1, 0)
1300  stride.io.s0_lookup.bits.vaddr := Cat(
1301    train_region_vaddr_s0, train_region_offset_s0, 0.U(log2Up(dcacheParameters.blockBytes).W)
1302  )
1303  stride.io.s0_lookup.bits.paddr := Cat(
1304    train_region_paddr_s0, train_region_offset_s0, 0.U(log2Up(dcacheParameters.blockBytes).W)
1305  )
1306  stride.io.s1_valid := active_gen_table.io.s1_sel_stride
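  // the stride table is indexed by the low PC bits and looked up with the full
  // block-aligned address, reconstructed from region address plus offset;
  // s1_valid is asserted only when the AGT selects the stride path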
1307
1308  pht.io.s2_agt_lookup := active_gen_table.io.s2_pht_lookup
1309  pht.io.agt_update := active_gen_table.io.s2_evict
1310
1311  val pht_gen_valid = pht.io.pf_gen_req.valid && io_pht_en
1312  val agt_gen_valid = active_gen_table.io.s2_pf_gen_req.valid
1313  val stride_gen_valid = stride.io.s2_gen_req.valid
1314  val pf_gen_req = Mux(agt_gen_valid || stride_gen_valid,
1315    Mux1H(Seq(
1316      agt_gen_valid -> active_gen_table.io.s2_pf_gen_req.bits,
1317      stride_gen_valid -> stride.io.s2_gen_req.bits
1318    )),
1319    pht.io.pf_gen_req.bits
1320  )
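  // AGT and stride generations are mutually exclusive (asserted below), so
  // Mux1H is safe; when neither fires, the PHT request is forwarded instead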
1321  assert(!(agt_gen_valid && stride_gen_valid), "AGT and stride must not generate prefetch requests in the same cycle")
1322  pf_filter.io.gen_req.valid := pht_gen_valid || agt_gen_valid || stride_gen_valid
1323  pf_filter.io.gen_req.bits := pf_gen_req
1324  io.tlb_req <> pf_filter.io.tlb_req
1325  pf_filter.io.pmp_resp := io.pmp_resp
1326  val is_valid_address = PmemRanges.map(_.cover(pf_filter.io.l2_pf_addr.bits)).reduce(_ || _)
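  // only addresses inside a valid physical memory range are prefetched to L2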
1327
1328  io.l2_req.valid := pf_filter.io.l2_pf_addr.valid && io.enable && is_valid_address
1329  io.l2_req.bits.addr := pf_filter.io.l2_pf_addr.bits
1330  io.l2_req.bits.source := MemReqSource.Prefetch2L2SMS.id.U
1331
1332  // For now, SMS does not issue L1 prefetch requests; the L1 channel below is tied off.
1333  io.l1_req.bits.paddr := pf_filter.io.l2_pf_addr.bits
1334  io.l1_req.bits.alias := pf_filter.io.pf_alias_bits
1335  io.l1_req.bits.is_store := true.B
1336  io.l1_req.bits.confidence := 1.U
1337  io.l1_req.bits.pf_source.value := L1_HW_PREFETCH_NULL
1338  io.l1_req.valid := false.B
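  // the L1 request channel stays fully wired for future use but is never asserted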
1339
1340  for((train, i) <- io.ld_in.zipWithIndex){
1341    XSPerfAccumulate(s"pf_train_miss_${i}", train.valid && train.bits.miss)
1342    XSPerfAccumulate(s"pf_train_prefetched_${i}", train.valid && isFromL1Prefetch(train.bits.meta_prefetch))
1343  }
1344  val trace = Wire(new L1MissTrace)
1345  trace.vaddr := 0.U
1346  trace.pc := 0.U
1347  trace.paddr := io.l2_req.bits.addr
1348  trace.source := pf_filter.io.debug_source_type
1349  val table = ChiselDB.createTable("L1SMSMissTrace_hart" + p(XSCoreParamsKey).HartId.toString, new L1MissTrace)
1350  table.log(trace, io.l2_req.fire, "SMSPrefetcher", clock, reset)
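  // each issued L2 prefetch is logged to a per-hart ChiselDB table, presumably
  // for offline trace analysis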
1351
1352  XSPerfAccumulate("sms_pf_gen_conflict",
1353    pht_gen_valid && agt_gen_valid
1354  )
1355  XSPerfAccumulate("sms_pht_disabled", pht.io.pf_gen_req.valid && !io_pht_en)
1356  XSPerfAccumulate("sms_agt_disabled", active_gen_table.io.s2_pf_gen_req.valid && !io_agt_en)
1357  XSPerfAccumulate("sms_pf_real_issued", io.l2_req.valid)
1358  XSPerfAccumulate("sms_l1_req_valid", io.l1_req.valid)
1359  XSPerfAccumulate("sms_l1_req_fire", io.l1_req.fire)
1360}
1361