xref: /XiangShan/src/main/scala/xiangshan/mem/prefetch/L1StreamPrefetcher.scala (revision f7063a43ab34da917ba6c670d21871314340c550)
package xiangshan.mem.prefetch

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import utility._
import xiangshan.cache.HasDCacheParameters
import xiangshan.cache.mmu._
import xiangshan.mem.{L1PrefetchReq, LdPrefetchTrainBundle}
import xiangshan.mem.trace._
import xiangshan.mem.L1PrefetchSource

trait HasStreamPrefetchHelper extends HasL1PrefetchHelper {
  // capacity related
  val STREAM_FILTER_SIZE = 4
  val BIT_VEC_ARRAY_SIZE = 16
  val ACTIVE_THRESHOLD = BIT_VEC_WITDH - 4
  val INIT_DEC_MODE = false

  // bit_vector [StreamBitVectorBundle]:
  // `X`: valid; `.`: invalid; `H`: hit
  // [X X X X X X X X X . . H . X X X]                                                         [. . X X X X . . . . . . . . . .]
  //                    hit in 12th slot & active           --------------------->             prefetch bit_vector [StreamPrefetchReqBundle]
  //                        |  <---------------------------- depth ---------------------------->
  //                                                                                           |<---- width ---->|
  val DEPTH_BYTES = 1024
  val DEPTH_CACHE_BLOCKS = DEPTH_BYTES / dcacheParameters.blockBytes
  val WIDTH_BYTES = 128
  val WIDTH_CACHE_BLOCKS = WIDTH_BYTES / dcacheParameters.blockBytes

  val L2_DEPTH_RATIO = 2
  val L2_WIDTH_BYTES = WIDTH_BYTES * 2
  val L2_WIDTH_CACHE_BLOCKS = L2_WIDTH_BYTES / dcacheParameters.blockBytes

  val L3_DEPTH_RATIO = 3
  val L3_WIDTH_BYTES = WIDTH_BYTES * 2 * 2
  val L3_WIDTH_CACHE_BLOCKS = L3_WIDTH_BYTES / dcacheParameters.blockBytes

  val DEPTH_LOOKAHEAD = 6
  val DEPTH_BITS = log2Up(DEPTH_CACHE_BLOCKS) + DEPTH_LOOKAHEAD
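
  // Worked example (assuming 64-byte cache blocks, the usual DCache config):
  // DEPTH_CACHE_BLOCKS = 1024 / 64 = 16, WIDTH_CACHE_BLOCKS = 128 / 64 = 2,
  // L2_WIDTH_CACHE_BLOCKS = 4, L3_WIDTH_CACHE_BLOCKS = 8, and
  // DEPTH_BITS = log2Up(16) + 6 = 10, leaving io.dynamic_depth six extra bits
  // of lookahead headroom beyond the static depth in blocks.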

  val ENABLE_DECR_MODE = false
  val ENABLE_STRICT_ACTIVE_DETECTION = true

  // constraints
  require((DEPTH_BYTES >= REGION_SIZE) && ((DEPTH_BYTES % REGION_SIZE) == 0) && ((DEPTH_BYTES / REGION_SIZE) > 0))
  require(((VADDR_HASH_WIDTH * 3) + BLK_ADDR_RAW_WIDTH) <= REGION_TAG_BITS)
  require(WIDTH_BYTES >= dcacheParameters.blockBytes)
}

class StreamBitVectorBundle(implicit p: Parameters) extends XSBundle with HasStreamPrefetchHelper {
  val tag = UInt(REGION_TAG_BITS.W)
  val bit_vec = UInt(BIT_VEC_WITDH.W)
  val active = Bool()
  // cnt tracks the number of set bits in bit_vec; its width can be optimized
  val cnt = UInt((log2Up(BIT_VEC_WITDH) + 1).W)
  val decr_mode = Bool()

  def reset(index: Int) = {
    tag := index.U
    bit_vec := 0.U
    active := false.B
    cnt := 0.U
    decr_mode := INIT_DEC_MODE.B
  }

  def tag_match(new_tag: UInt): Bool = {
    region_hash_tag(tag) === region_hash_tag(new_tag)
  }

  def alloc(alloc_tag: UInt, alloc_bit_vec: UInt, alloc_active: Bool, alloc_decr_mode: Bool) = {
    tag := alloc_tag
    bit_vec := alloc_bit_vec
    active := alloc_active
    cnt := 1.U
    if(ENABLE_DECR_MODE) {
      decr_mode := alloc_decr_mode
    } else {
      decr_mode := INIT_DEC_MODE.B
    }

    assert(PopCount(alloc_bit_vec) === 1.U, "alloc vector should be one hot")
  }

  def update(update_bit_vec: UInt, update_active: Bool) = {
    // if the slot was 0 before, increment cnt
    val cnt_en = !((bit_vec & update_bit_vec).orR)
    val cnt_next = Mux(cnt_en, cnt + 1.U, cnt)

    bit_vec := bit_vec | update_bit_vec
    cnt := cnt_next
    when(cnt_next >= ACTIVE_THRESHOLD.U) {
      active := true.B
    }
    when(update_active) {
      active := true.B
    }

    assert(PopCount(update_bit_vec) === 1.U, "update vector should be one hot")
    assert(cnt <= BIT_VEC_WITDH.U, "cnt should never exceed the bit vector size")
  }
}
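
// Activation by example (assuming a 16-slot bit vector, so ACTIVE_THRESHOLD is
// 12): an entry becomes active either after update() has set 12 distinct slots
// of its region, or immediately via alloc_active/update_active when an
// adjacent region is already an active stream.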

class StreamPrefetchReqBundle(implicit p: Parameters) extends XSBundle with HasStreamPrefetchHelper {
  val region = UInt(REGION_TAG_BITS.W)
  val bit_vec = UInt(BIT_VEC_WITDH.W)
  val sink = UInt(SINK_BITS.W)
  val source = new L1PrefetchSource()

  // align the prefetch vaddr and width to the region
  def getStreamPrefetchReqBundle(valid: Bool, vaddr: UInt, width: Int, decr_mode: Bool, sink: UInt, source: UInt): StreamPrefetchReqBundle = {
    val res = Wire(new StreamPrefetchReqBundle)
    res.region := get_region_tag(vaddr)
    res.sink := sink
    res.source.value := source

    val region_bits = get_region_bits(vaddr)
    val region_bit_vec = UIntToOH(region_bits)
    res.bit_vec := Mux(
      decr_mode,
      (0 until width).map{ case i => region_bit_vec >> i }.reduce(_ | _),
      (0 until width).map{ case i => region_bit_vec << i }.reduce(_ | _)
    )

    assert(!valid || PopCount(res.bit_vec) <= width.U, "the actual number of prefetched blocks should be less than or equal to the requested width")
    assert(!valid || PopCount(res.bit_vec) >= 1.U, "at least one block should be included")
    assert(sink <= SINK_L3, "invalid sink")
    for(i <- 0 until BIT_VEC_WITDH) {
      when(decr_mode) {
        when(i.U > region_bits) {
          assert(!valid || res.bit_vec(i) === 0.U, s"res.bit_vec(${i}) is not zero in decr_mode, prefetch vector is wrong!")
        }.elsewhen(i.U === region_bits) {
          assert(!valid || res.bit_vec(i) === 1.U, s"res.bit_vec(${i}) is zero in decr_mode, prefetch vector is wrong!")
        }
      }.otherwise {
        when(i.U < region_bits) {
          assert(!valid || res.bit_vec(i) === 0.U, s"res.bit_vec(${i}) is not zero in incr_mode, prefetch vector is wrong!")
        }.elsewhen(i.U === region_bits) {
          assert(!valid || res.bit_vec(i) === 1.U, s"res.bit_vec(${i}) is zero in incr_mode, prefetch vector is wrong!")
        }
      }
    }

    res
  }
}
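
// Vector generation by example (illustrative values): in incr_mode with
// region_bits = 5 and width = 2, bit_vec = OH(5) | OH(6), i.e. blocks 5 and 6
// of the region. Bits shifted past the region boundary are truncated on the
// connection, which is why PopCount(bit_vec) can be smaller than width near a
// region's edge.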

class StreamBitVectorArray(implicit p: Parameters) extends XSModule with HasStreamPrefetchHelper {
  val io = IO(new XSBundle {
    val enable = Input(Bool())
    // TODO: flush all entries when a process switch happens, or disable stream prefetch for a while
    val flush = Input(Bool())
    val dynamic_depth = Input(UInt(DEPTH_BITS.W))
    val train_req = Flipped(DecoupledIO(new PrefetchReqBundle))
    val prefetch_req = ValidIO(new StreamPrefetchReqBundle)

    // Stride sends its lookup req here
    val stream_lookup_req  = Flipped(ValidIO(new PrefetchReqBundle))
    val stream_lookup_resp = Output(Bool())
  })
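
  // Pipeline overview: s0 hashes the train vaddr into a region tag and matches
  // all entries in parallel; s1 allocates or updates the matched entry; s2
  // reads out active/decr_mode and builds the per-level request bundles; s3,
  // s4 and s5 then issue the L1, L2 and L3 requests respectively through the
  // single prefetch_req port.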

  val array = Reg(Vec(BIT_VEC_ARRAY_SIZE, new StreamBitVectorBundle))
  val replacement = ReplacementPolicy.fromString("plru", BIT_VEC_ARRAY_SIZE)

  // s0: generate region tag, parallel match
  val s0_can_accept = Wire(Bool())
  val s0_valid = io.train_req.fire
  val s0_vaddr = io.train_req.bits.vaddr
  val s0_region_bits = get_region_bits(s0_vaddr)
  val s0_region_tag = get_region_tag(s0_vaddr)
  val s0_region_tag_plus_one = get_region_tag(s0_vaddr) + 1.U
  val s0_region_tag_minus_one = get_region_tag(s0_vaddr) - 1.U
  val s0_region_tag_match_vec = array.map(_.tag_match(s0_region_tag))
  val s0_region_tag_plus_one_match_vec = array.map(_.tag_match(s0_region_tag_plus_one))
  val s0_region_tag_minus_one_match_vec = array.map(_.tag_match(s0_region_tag_minus_one))
  val s0_hit = Cat(s0_region_tag_match_vec).orR
  val s0_plus_one_hit = Cat(s0_region_tag_plus_one_match_vec).orR
  val s0_minus_one_hit = Cat(s0_region_tag_minus_one_match_vec).orR
  val s0_hit_vec = VecInit(s0_region_tag_match_vec).asUInt
  val s0_index = Mux(s0_hit, OHToUInt(s0_hit_vec), replacement.way)
  val s0_plus_one_index = OHToUInt(VecInit(s0_region_tag_plus_one_match_vec).asUInt)
  val s0_minus_one_index = OHToUInt(VecInit(s0_region_tag_minus_one_match_vec).asUInt)
  io.train_req.ready := s0_can_accept

  when(s0_valid) {
    replacement.access(s0_index)
  }

  assert(!s0_valid || PopCount(VecInit(s0_region_tag_match_vec)) <= 1.U, "req region should match no more than 1 entry")
  assert(!s0_valid || PopCount(VecInit(s0_region_tag_plus_one_match_vec)) <= 1.U, "req region plus 1 should match no more than 1 entry")
  assert(!s0_valid || PopCount(VecInit(s0_region_tag_minus_one_match_vec)) <= 1.U, "req region minus 1 should match no more than 1 entry")
  assert(!s0_valid || !(s0_hit && s0_plus_one_hit && (s0_index === s0_plus_one_index)), "region and region plus 1 should not share an index")
  assert(!s0_valid || !(s0_hit && s0_minus_one_hit && (s0_index === s0_minus_one_index)), "region and region minus 1 should not share an index")
  assert(!s0_valid || !(s0_plus_one_hit && s0_minus_one_hit && (s0_minus_one_index === s0_plus_one_index)), "region plus 1 and region minus 1 should not share an index")
  assert(!(s0_valid && RegNext(s0_valid) && !s0_hit && !RegNext(s0_hit) && replacement.way === RegNext(replacement.way)), "replacement error")

  XSPerfAccumulate("s0_valid_train_req", s0_valid)
  val s0_hit_pattern_vec = Seq(s0_hit, s0_plus_one_hit, s0_minus_one_hit)
  for(i <- 0 until (1 << s0_hit_pattern_vec.size)) {
    XSPerfAccumulate(s"s0_hit_pattern_${toBinary(i)}", (VecInit(s0_hit_pattern_vec).asUInt === i.U) && s0_valid)
  }
  XSPerfAccumulate("s0_replace_the_neighbor", s0_valid && !s0_hit && ((s0_plus_one_hit && (s0_index === s0_plus_one_index)) || (s0_minus_one_hit && (s0_index === s0_minus_one_index))))
  XSPerfAccumulate("s0_req_valid", io.train_req.valid)
  XSPerfAccumulate("s0_req_cannot_accept", io.train_req.valid && !io.train_req.ready)

  val ratio_const = WireInit(Constantin.createRecord("l2DepthRatio" + p(XSCoreParamsKey).HartId.toString, initValue = L2_DEPTH_RATIO.U))
  val ratio = ratio_const(3, 0)

  val l3_ratio_const = WireInit(Constantin.createRecord("l3DepthRatio" + p(XSCoreParamsKey).HartId.toString, initValue = L3_DEPTH_RATIO.U))
  val l3_ratio = l3_ratio_const(3, 0)

  // s1: alloc or update
  val s1_valid = RegNext(s0_valid)
  val s1_index = RegEnable(s0_index, s0_valid)
  val s1_plus_one_index = RegEnable(s0_plus_one_index, s0_valid)
  val s1_minus_one_index = RegEnable(s0_minus_one_index, s0_valid)
  val s1_hit = RegEnable(s0_hit, s0_valid)
  val s1_plus_one_hit = if(ENABLE_STRICT_ACTIVE_DETECTION)
                            RegEnable(s0_plus_one_hit, s0_valid) && array(s1_plus_one_index).active && (array(s1_plus_one_index).cnt >= ACTIVE_THRESHOLD.U)
                        else
                            RegEnable(s0_plus_one_hit, s0_valid) && array(s1_plus_one_index).active
  val s1_minus_one_hit = if(ENABLE_STRICT_ACTIVE_DETECTION)
                            RegEnable(s0_minus_one_hit, s0_valid) && array(s1_minus_one_index).active && (array(s1_minus_one_index).cnt >= ACTIVE_THRESHOLD.U)
                        else
                            RegEnable(s0_minus_one_hit, s0_valid) && array(s1_minus_one_index).active
  val s1_region_tag = RegEnable(s0_region_tag, s0_valid)
  val s1_region_bits = RegEnable(s0_region_bits, s0_valid)
  val s1_alloc = s1_valid && !s1_hit
  val s1_update = s1_valid && s1_hit
  val s1_pf_l1_incr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) + io.dynamic_depth, 0.U(BLOCK_OFFSET.W))
  val s1_pf_l1_decr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) - io.dynamic_depth, 0.U(BLOCK_OFFSET.W))
  val s1_pf_l2_incr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) + (io.dynamic_depth << ratio), 0.U(BLOCK_OFFSET.W))
  val s1_pf_l2_decr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) - (io.dynamic_depth << ratio), 0.U(BLOCK_OFFSET.W))
  val s1_pf_l3_incr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) + (io.dynamic_depth << l3_ratio), 0.U(BLOCK_OFFSET.W))
  val s1_pf_l3_decr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) - (io.dynamic_depth << l3_ratio), 0.U(BLOCK_OFFSET.W))
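  // The six vaddrs above scale the prefetch distance per level: with
  // io.dynamic_depth = 16 blocks and the default ratios (L2_DEPTH_RATIO = 2,
  // L3_DEPTH_RATIO = 3), the L1 req targets +16 blocks, the L2 req +64 blocks
  // and the L3 req +128 blocks (the same distances backwards in decr_mode),
  // so deeper cache levels are warmed further ahead of the stream.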
  // TODO: remove this
  val s1_can_send_pf = Mux(s1_update, !((array(s1_index).bit_vec & UIntToOH(s1_region_bits)).orR), true.B)
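  // s0_can_accept stalls a same-region train req for one cycle: the s1 write
  // to array(s1_index) has not landed yet, so an s0 match on the same hashed
  // region would read stale state.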
  s0_can_accept := !(s1_valid && (region_hash_tag(s1_region_tag) === region_hash_tag(s0_region_tag)))

  when(s1_alloc) {
    // alloc a new entry
    array(s1_index).alloc(
      alloc_tag = s1_region_tag,
      alloc_bit_vec = UIntToOH(s1_region_bits),
      alloc_active = s1_plus_one_hit || s1_minus_one_hit,
      // a hit on the next-higher region implies a descending stream
      alloc_decr_mode = RegEnable(s0_plus_one_hit, s0_valid))

  }.elsewhen(s1_update) {
    // update an existing entry
    assert(array(s1_index).cnt =/= 0.U || array(s1_index).tag === s1_index, "entry should have been allocated before")
    array(s1_index).update(
      update_bit_vec = UIntToOH(s1_region_bits),
      update_active = s1_plus_one_hit || s1_minus_one_hit)
  }

  XSPerfAccumulate("s1_alloc", s1_alloc)
  XSPerfAccumulate("s1_update", s1_update)
  XSPerfAccumulate("s1_active_plus_one_hit", s1_valid && s1_plus_one_hit)
  XSPerfAccumulate("s1_active_minus_one_hit", s1_valid && s1_minus_one_hit)

  // s2: trigger a prefetch if the hit bit vector is active; compute the meta of the prefetch req
  val s2_valid = RegNext(s1_valid)
  val s2_index = RegEnable(s1_index, s1_valid)
  val s2_region_bits = RegEnable(s1_region_bits, s1_valid)
  val s2_region_tag = RegEnable(s1_region_tag, s1_valid)
  val s2_pf_l1_incr_vaddr = RegEnable(s1_pf_l1_incr_vaddr, s1_valid)
  val s2_pf_l1_decr_vaddr = RegEnable(s1_pf_l1_decr_vaddr, s1_valid)
  val s2_pf_l2_incr_vaddr = RegEnable(s1_pf_l2_incr_vaddr, s1_valid)
  val s2_pf_l2_decr_vaddr = RegEnable(s1_pf_l2_decr_vaddr, s1_valid)
  val s2_pf_l3_incr_vaddr = RegEnable(s1_pf_l3_incr_vaddr, s1_valid)
  val s2_pf_l3_decr_vaddr = RegEnable(s1_pf_l3_decr_vaddr, s1_valid)
  val s2_can_send_pf = RegEnable(s1_can_send_pf, s1_valid)
  val s2_active = array(s2_index).active
  val s2_decr_mode = array(s2_index).decr_mode
  val s2_l1_vaddr = Mux(s2_decr_mode, s2_pf_l1_decr_vaddr, s2_pf_l1_incr_vaddr)
  val s2_l2_vaddr = Mux(s2_decr_mode, s2_pf_l2_decr_vaddr, s2_pf_l2_incr_vaddr)
  val s2_l3_vaddr = Mux(s2_decr_mode, s2_pf_l3_decr_vaddr, s2_pf_l3_incr_vaddr)
  val s2_will_send_pf = s2_valid && s2_active && s2_can_send_pf
  val s2_pf_req_valid = s2_will_send_pf && io.enable
  val s2_pf_l1_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l1_vaddr,
    width = WIDTH_CACHE_BLOCKS,
    decr_mode = s2_decr_mode,
    sink = SINK_L1,
    source = L1_HW_PREFETCH_STREAM)
  val s2_pf_l2_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l2_vaddr,
    width = L2_WIDTH_CACHE_BLOCKS,
    decr_mode = s2_decr_mode,
    sink = SINK_L2,
    source = L1_HW_PREFETCH_STREAM)
  val s2_pf_l3_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l3_vaddr,
    width = L3_WIDTH_CACHE_BLOCKS,
    decr_mode = s2_decr_mode,
    sink = SINK_L3,
    source = L1_HW_PREFETCH_STREAM)
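
  // All three bundles derive from the same training hit; they differ only in
  // target vaddr (per-level depth), width (doubled per level by default) and
  // sink, so one active stream feeds L1, L2 and L3 at once.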

  XSPerfAccumulate("s2_valid", s2_valid)
  XSPerfAccumulate("s2_will_not_send_pf", s2_valid && !s2_will_send_pf)
  XSPerfAccumulate("s2_will_send_decr_pf", s2_valid && s2_will_send_pf && s2_decr_mode)
  XSPerfAccumulate("s2_will_send_incr_pf", s2_valid && s2_will_send_pf && !s2_decr_mode)

  // s3: send the L1 prefetch req out
  val s3_pf_l1_valid = RegNext(s2_pf_req_valid)
  val s3_pf_l1_bits = RegEnable(s2_pf_l1_req_bits, s2_pf_req_valid)
  val s3_pf_l2_valid = RegNext(s2_pf_req_valid)
  val s3_pf_l2_bits = RegEnable(s2_pf_l2_req_bits, s2_pf_req_valid)
  val s3_pf_l3_bits = RegEnable(s2_pf_l3_req_bits, s2_pf_req_valid)

  XSPerfAccumulate("s3_pf_sent", s3_pf_l1_valid)

  // s4: send the L2 prefetch req out
  val s4_pf_l2_valid = RegNext(s3_pf_l2_valid)
  val s4_pf_l2_bits = RegEnable(s3_pf_l2_bits, s3_pf_l2_valid)
  val s4_pf_l3_bits = RegEnable(s3_pf_l3_bits, s3_pf_l2_valid)

  val enable_l3_pf = WireInit(Constantin.createRecord("enableL3StreamPrefetch" + p(XSCoreParamsKey).HartId.toString, initValue = 0.U)) =/= 0.U
  // s5: send the L3 prefetch req out
  val s5_pf_l3_valid = RegNext(s4_pf_l2_valid) && enable_l3_pf
  val s5_pf_l3_bits = RegEnable(s4_pf_l3_bits, s4_pf_l2_valid)

  io.prefetch_req.valid := s3_pf_l1_valid || s4_pf_l2_valid || s5_pf_l3_valid
  io.prefetch_req.bits := Mux(s3_pf_l1_valid, s3_pf_l1_bits, Mux(s4_pf_l2_valid, s4_pf_l2_bits, s5_pf_l3_bits))
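  // The mux above statically prioritizes L1 (s3) over L2 (s4) over L3 (s5) on
  // the shared port; a lower-priority req that loses arbitration is dropped,
  // not replayed, and "s4_pf_blocked" below counts L2 reqs lost this way.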

  XSPerfAccumulate("s4_pf_sent", !s3_pf_l1_valid && s4_pf_l2_valid)
  XSPerfAccumulate("s4_pf_blocked", s3_pf_l1_valid && s4_pf_l2_valid)
  XSPerfAccumulate("pf_sent", io.prefetch_req.valid)

  // Stride lookup starts here
  // S0: Stride sends its req
  val s0_lookup_valid = io.stream_lookup_req.valid
  val s0_lookup_vaddr = io.stream_lookup_req.bits.vaddr
  val s0_lookup_tag = get_region_tag(s0_lookup_vaddr)
  // S1: match
  val s1_lookup_valid = RegNext(s0_lookup_valid)
  val s1_lookup_tag = RegEnable(s0_lookup_tag, s0_lookup_valid)
  val s1_lookup_tag_match_vec = array.map(_.tag_match(s1_lookup_tag))
  val s1_lookup_hit = VecInit(s1_lookup_tag_match_vec).asUInt.orR
  val s1_lookup_index = OHToUInt(VecInit(s1_lookup_tag_match_vec))
  // S2: read the active bit out
  val s2_lookup_valid = RegNext(s1_lookup_valid)
  val s2_lookup_hit = RegEnable(s1_lookup_hit, s1_lookup_valid)
  val s2_lookup_index = RegEnable(s1_lookup_index, s1_lookup_valid)
  val s2_lookup_active = array(s2_lookup_index).active
  // S3: send the result back to Stride
  val s3_lookup_valid = RegNext(s2_lookup_valid)
  val s3_lookup_hit = RegEnable(s2_lookup_hit, s2_lookup_valid)
  val s3_lookup_active = RegEnable(s2_lookup_active, s2_lookup_valid)
  io.stream_lookup_resp := s3_lookup_valid && s3_lookup_hit && s3_lookup_active
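  // The lookup side-channel above has a fixed three-cycle latency (match, read
  // the active bit, respond); the resp is a single Bool meaning "this vaddr
  // falls in an active stream", qualified by hit and valid.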

  // reset the meta to avoid the multi-hit problem
  for(i <- 0 until BIT_VEC_ARRAY_SIZE) {
    when(reset.asBool || RegNext(io.flush)) {
      array(i).reset(i)
    }
  }

  XSPerfHistogram("bit_vector_active", PopCount(VecInit(array.map(_.active)).asUInt), true.B, 0, BIT_VEC_ARRAY_SIZE, 1)
  XSPerfHistogram("bit_vector_decr_mode", PopCount(VecInit(array.map(_.decr_mode)).asUInt), true.B, 0, BIT_VEC_ARRAY_SIZE, 1)
  XSPerfAccumulate("hash_conflict", s0_valid && s2_valid && (s0_region_tag =/= s2_region_tag) && (region_hash_tag(s0_region_tag) === region_hash_tag(s2_region_tag)))
}
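
// A minimal driving sketch (hypothetical; everything on the right-hand side
// other than the io fields is invented for illustration):
//
//   val stream = Module(new StreamBitVectorArray)
//   stream.io.enable        := csr_pf_enable        // hypothetical CSR enable
//   stream.io.flush         := false.B
//   stream.io.dynamic_depth := 16.U                 // start 16 blocks ahead
//   stream.io.train_req     <> ld_train_arb.io.out  // hypothetical train source
//   stream.io.stream_lookup_req.valid := false.B
//   stream.io.stream_lookup_req.bits  := DontCare
//   // stream.io.prefetch_req then feeds the prefetch filter / sink queues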