package xiangshan.mem.prefetch

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import utility._
import xiangshan.cache.HasDCacheParameters
import xiangshan.cache.mmu._
import xiangshan.mem.{L1PrefetchReq, LdPrefetchTrainBundle}
import xiangshan.mem.trace._
import xiangshan.mem.L1PrefetchSource

trait HasStreamPrefetchHelper extends HasL1PrefetchHelper {
  // capacity related
  val STREAM_FILTER_SIZE = 4
  val BIT_VEC_ARRAY_SIZE = 16
  val ACTIVE_THRESHOLD = BIT_VEC_WITDH - 4
  val INIT_DEC_MODE = false

  // bit_vector [StreamBitVectorBundle]:
  // `X`: valid; `.`: invalid; `H`: hit
  // [X X X X X X X X X . . H . X X X] [. . X X X X . . . . . . . . . .]
  //                        hit in 12th slot & active -----------------> prefetch bit_vector [StreamPrefetchReqBundle]
  // |<----------------------------- depth ----------------------------->|
  //                                                                      |<-- width -->|
  val DEPTH_BYTES = 1024
  val DEPTH_CACHE_BLOCKS = DEPTH_BYTES / dcacheParameters.blockBytes
  val WIDTH_BYTES = 128
  val WIDTH_CACHE_BLOCKS = WIDTH_BYTES / dcacheParameters.blockBytes

  val L2_DEPTH_RATIO = 2
  val L2_WIDTH_BYTES = WIDTH_BYTES * 2
  val L2_WIDTH_CACHE_BLOCKS = L2_WIDTH_BYTES / dcacheParameters.blockBytes

  val L3_DEPTH_RATIO = 3
  val L3_WIDTH_BYTES = WIDTH_BYTES * 2 * 2
  val L3_WIDTH_CACHE_BLOCKS = L3_WIDTH_BYTES / dcacheParameters.blockBytes

  val DEPTH_LOOKAHEAD = 6
  val DEPTH_BITS = log2Up(DEPTH_CACHE_BLOCKS) + DEPTH_LOOKAHEAD

  val ENABLE_DECR_MODE = false
  val ENABLE_STRICT_ACTIVE_DETECTION = true

  // constraints
  require((DEPTH_BYTES >= REGION_SIZE) && ((DEPTH_BYTES % REGION_SIZE) == 0) && ((DEPTH_BYTES / REGION_SIZE) > 0))
  require(((VADDR_HASH_WIDTH * 3) + BLK_ADDR_RAW_WIDTH) <= REGION_TAG_BITS)
  require(WIDTH_BYTES >= dcacheParameters.blockBytes)
}
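
// Worked example of the derived constants (a sketch only; actual values follow
// dcacheParameters.blockBytes, assumed here to be the common 64-byte block size):
//   DEPTH_CACHE_BLOCKS    = 1024 / 64 = 16 blocks of baseline lookahead depth
//   WIDTH_CACHE_BLOCKS    =  128 / 64 =  2 blocks per L1 prefetch request
//   L2_WIDTH_CACHE_BLOCKS =  256 / 64 =  4 blocks per L2 prefetch request
//   L3_WIDTH_CACHE_BLOCKS =  512 / 64 =  8 blocks per L3 prefetch request
//   ACTIVE_THRESHOLD      = BIT_VEC_WITDH - 4 distinct blocks touched in a region
//                           before it is promoted to an active stream
//                           (12 for a 16-entry bit vector)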

class StreamBitVectorBundle(implicit p: Parameters) extends XSBundle with HasStreamPrefetchHelper {
  val tag = UInt(REGION_TAG_BITS.W)
  val bit_vec = UInt(BIT_VEC_WITDH.W)
  val active = Bool()
  // cnt can be optimized
  val cnt = UInt((log2Up(BIT_VEC_WITDH) + 1).W)
  val decr_mode = Bool()

  def reset(index: Int) = {
    tag := index.U
    bit_vec := 0.U
    active := false.B
    cnt := 0.U
    decr_mode := INIT_DEC_MODE.B
  }

  def tag_match(new_tag: UInt): Bool = {
    region_hash_tag(tag) === region_hash_tag(new_tag)
  }

  def alloc(alloc_tag: UInt, alloc_bit_vec: UInt, alloc_active: Bool, alloc_decr_mode: Bool) = {
    tag := alloc_tag
    bit_vec := alloc_bit_vec
    active := alloc_active
    cnt := 1.U
    if(ENABLE_DECR_MODE) {
      decr_mode := alloc_decr_mode
    } else {
      decr_mode := INIT_DEC_MODE.B
    }

    assert(PopCount(alloc_bit_vec) === 1.U, "alloc vector should be one hot")
  }

  def update(update_bit_vec: UInt, update_active: Bool) = {
    // if the slot is 0 before, increment cnt
    val cnt_en = !((bit_vec & update_bit_vec).orR)
    val cnt_next = Mux(cnt_en, cnt + 1.U, cnt)

    bit_vec := bit_vec | update_bit_vec
    cnt := cnt_next
    when(cnt_next >= ACTIVE_THRESHOLD.U) {
      active := true.B
    }
    when(update_active) {
      active := true.B
    }

    assert(PopCount(update_bit_vec) === 1.U, "update vector should be one hot")
    assert(cnt <= BIT_VEC_WITDH.U, "cnt should always be less than or equal to bit vector size")
  }
}

class StreamPrefetchReqBundle(implicit p: Parameters) extends XSBundle with HasStreamPrefetchHelper {
  val region = UInt(REGION_TAG_BITS.W)
  val bit_vec = UInt(BIT_VEC_WITDH.W)
  val sink = UInt(SINK_BITS.W)
  val source = new L1PrefetchSource()

  // align prefetch vaddr and width to region
  def getStreamPrefetchReqBundle(valid: Bool, vaddr: UInt, width: Int, decr_mode: Bool, sink: UInt, source: UInt): StreamPrefetchReqBundle = {
    val res = Wire(new StreamPrefetchReqBundle)
    res.region := get_region_tag(vaddr)
    res.sink := sink
    res.source.value := source

    val region_bits = get_region_bits(vaddr)
    val region_bit_vec = UIntToOH(region_bits)
    res.bit_vec := Mux(
      decr_mode,
      (0 until width).map{ case i => region_bit_vec >> i}.reduce(_ | _),
      (0 until width).map{ case i => region_bit_vec << i}.reduce(_ | _)
    )

    assert(!valid || PopCount(res.bit_vec) <= width.U, "actual prefetch block number should be less than or equal to WIDTH_CACHE_BLOCKS")
    assert(!valid || PopCount(res.bit_vec) >= 1.U, "at least one block should be included")
    assert(sink <= SINK_L3, "invalid sink")
    for(i <- 0 until BIT_VEC_WITDH) {
      when(decr_mode) {
        when(i.U > region_bits) {
          assert(!valid || res.bit_vec(i) === 0.U, s"res.bit_vec(${i}) is not zero in decr_mode, prefetch vector is wrong!")
        }.elsewhen(i.U === region_bits) {
          assert(!valid || res.bit_vec(i) === 1.U, s"res.bit_vec(${i}) is zero in decr_mode, prefetch vector is wrong!")
        }
      }.otherwise {
        when(i.U < region_bits) {
          assert(!valid || res.bit_vec(i) === 0.U, s"res.bit_vec(${i}) is not zero in incr_mode, prefetch vector is wrong!")
        }.elsewhen(i.U === region_bits) {
          assert(!valid || res.bit_vec(i) === 1.U, s"res.bit_vec(${i}) is zero in incr_mode, prefetch vector is wrong!")
        }
      }
    }

    res
  }
}
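
// Worked example of getStreamPrefetchReqBundle (illustrative values, not generated hardware),
// shown for a 16-entry bit vector:
// with width = 2, incr mode, and a vaddr whose in-region block offset (get_region_bits) is 5,
//   region_bit_vec = UIntToOH(5)             = 0b0000_0000_0010_0000
//   res.bit_vec    = (vec << 0) | (vec << 1) = 0b0000_0000_0110_0000   // blocks 5 and 6
// In decr mode the shifts go right instead, so the same request would cover blocks 5 and 4.
// Bits shifted past the region boundary are dropped when the result is truncated back to
// BIT_VEC_WITDH bits, which is why the assertions only require 1 <= PopCount(bit_vec) <= width.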

class StreamBitVectorArray(implicit p: Parameters) extends XSModule with HasStreamPrefetchHelper {
  val io = IO(new XSBundle {
    val enable = Input(Bool())
    // TODO: flush all entries when a process change happens, or disable stream prefetch for a while
    val flush = Input(Bool())
    val dynamic_depth = Input(UInt(DEPTH_BITS.W))
    val train_req = Flipped(DecoupledIO(new PrefetchReqBundle))
    val prefetch_req = ValidIO(new StreamPrefetchReqBundle)

    // Stride sends lookup reqs here
    val stream_lookup_req = Flipped(ValidIO(new PrefetchReqBundle))
    val stream_lookup_resp = Output(Bool())
  })

  val array = Reg(Vec(BIT_VEC_ARRAY_SIZE, new StreamBitVectorBundle))
  val replacement = ReplacementPolicy.fromString("plru", BIT_VEC_ARRAY_SIZE)

  // s0: generate region tag, parallel match
  val s0_can_accept = Wire(Bool())
  val s0_valid = io.train_req.fire
  val s0_vaddr = io.train_req.bits.vaddr
  val s0_region_bits = get_region_bits(s0_vaddr)
  val s0_region_tag = get_region_tag(s0_vaddr)
  val s0_region_tag_plus_one = get_region_tag(s0_vaddr) + 1.U
  val s0_region_tag_minus_one = get_region_tag(s0_vaddr) - 1.U
  val s0_region_tag_match_vec = array.map(_.tag_match(s0_region_tag))
  val s0_region_tag_plus_one_match_vec = array.map(_.tag_match(s0_region_tag_plus_one))
  val s0_region_tag_minus_one_match_vec = array.map(_.tag_match(s0_region_tag_minus_one))
  val s0_hit = Cat(s0_region_tag_match_vec).orR
  val s0_plus_one_hit = Cat(s0_region_tag_plus_one_match_vec).orR
  val s0_minus_one_hit = Cat(s0_region_tag_minus_one_match_vec).orR
  val s0_hit_vec = VecInit(s0_region_tag_match_vec).asUInt
  val s0_index = Mux(s0_hit, OHToUInt(s0_hit_vec), replacement.way)
  val s0_plus_one_index = OHToUInt(VecInit(s0_region_tag_plus_one_match_vec).asUInt)
  val s0_minus_one_index = OHToUInt(VecInit(s0_region_tag_minus_one_match_vec).asUInt)
  io.train_req.ready := s0_can_accept

  when(s0_valid) {
    replacement.access(s0_index)
  }

  assert(!s0_valid || PopCount(VecInit(s0_region_tag_match_vec)) <= 1.U, "req region should match no more than 1 entry")
  assert(!s0_valid || PopCount(VecInit(s0_region_tag_plus_one_match_vec)) <= 1.U, "req region plus 1 should match no more than 1 entry")
  assert(!s0_valid || PopCount(VecInit(s0_region_tag_minus_one_match_vec)) <= 1.U, "req region minus 1 should match no more than 1 entry")
  assert(!s0_valid || !(s0_hit && s0_plus_one_hit && (s0_index === s0_plus_one_index)), "region and region plus 1 index match failed")
  assert(!s0_valid || !(s0_hit && s0_minus_one_hit && (s0_index === s0_minus_one_index)), "region and region minus 1 index match failed")
  assert(!s0_valid || !(s0_plus_one_hit && s0_minus_one_hit && (s0_minus_one_index === s0_plus_one_index)), "region plus 1 and region minus 1 index match failed")
  assert(!(s0_valid && RegNext(s0_valid) && !s0_hit && !RegNext(s0_hit) && replacement.way === RegNext(replacement.way)), "replacement error")

  XSPerfAccumulate("s0_valid_train_req", s0_valid)
  val s0_hit_pattern_vec = Seq(s0_hit, s0_plus_one_hit, s0_minus_one_hit)
  for(i <- 0 until (1 << s0_hit_pattern_vec.size)) {
    XSPerfAccumulate(s"s0_hit_pattern_${toBinary(i)}", (VecInit(s0_hit_pattern_vec).asUInt === i.U) && s0_valid)
  }
  XSPerfAccumulate("s0_replace_the_neighbor", s0_valid && !s0_hit && ((s0_plus_one_hit && (s0_index === s0_plus_one_index)) || (s0_minus_one_hit && (s0_index === s0_minus_one_index))))
  XSPerfAccumulate("s0_req_valid", io.train_req.valid)
  XSPerfAccumulate("s0_req_cannot_accept", io.train_req.valid && !io.train_req.ready)

  val ratio_const = WireInit(Constantin.createRecord("l2DepthRatio" + p(XSCoreParamsKey).HartId.toString, initValue = L2_DEPTH_RATIO.U))
  val ratio = ratio_const(3, 0)

  val l3_ratio_const = WireInit(Constantin.createRecord("l3DepthRatio" + p(XSCoreParamsKey).HartId.toString, initValue = L3_DEPTH_RATIO.U))
  val l3_ratio = l3_ratio_const(3, 0)
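
  // Lookahead sketch for the s1 prefetch-address computation below (assuming 64-byte
  // blocks and the default ratios; ratio/l3_ratio can be overridden at runtime via Constantin):
  //   L1 target block = trained block +/- dynamic_depth
  //   L2 target block = trained block +/- (dynamic_depth << ratio)     // ratio    defaults to L2_DEPTH_RATIO = 2
  //   L3 target block = trained block +/- (dynamic_depth << l3_ratio)  // l3_ratio defaults to L3_DEPTH_RATIO = 3
  // e.g. with dynamic_depth = 16, L2 requests run 64 blocks (4 KiB) ahead of the training
  // stream and L3 requests run 128 blocks (8 KiB) ahead; +/- follows the entry's decr_mode.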

  // s1: alloc or update
  val s1_valid = RegNext(s0_valid)
  val s1_index = RegEnable(s0_index, s0_valid)
  val s1_plus_one_index = RegEnable(s0_plus_one_index, s0_valid)
  val s1_minus_one_index = RegEnable(s0_minus_one_index, s0_valid)
  val s1_hit = RegEnable(s0_hit, s0_valid)
  val s1_plus_one_hit = if(ENABLE_STRICT_ACTIVE_DETECTION)
                          RegEnable(s0_plus_one_hit, s0_valid) && array(s1_plus_one_index).active && (array(s1_plus_one_index).cnt >= ACTIVE_THRESHOLD.U)
                        else
                          RegEnable(s0_plus_one_hit, s0_valid) && array(s1_plus_one_index).active
  val s1_minus_one_hit = if(ENABLE_STRICT_ACTIVE_DETECTION)
                           RegEnable(s0_minus_one_hit, s0_valid) && array(s1_minus_one_index).active && (array(s1_minus_one_index).cnt >= ACTIVE_THRESHOLD.U)
                         else
                           RegEnable(s0_minus_one_hit, s0_valid) && array(s1_minus_one_index).active
  val s1_region_tag = RegEnable(s0_region_tag, s0_valid)
  val s1_region_bits = RegEnable(s0_region_bits, s0_valid)
  val s1_alloc = s1_valid && !s1_hit
  val s1_update = s1_valid && s1_hit
  val s1_pf_l1_incr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) + io.dynamic_depth, 0.U(BLOCK_OFFSET.W))
  val s1_pf_l1_decr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) - io.dynamic_depth, 0.U(BLOCK_OFFSET.W))
  val s1_pf_l2_incr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) + (io.dynamic_depth << ratio), 0.U(BLOCK_OFFSET.W))
  val s1_pf_l2_decr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) - (io.dynamic_depth << ratio), 0.U(BLOCK_OFFSET.W))
  val s1_pf_l3_incr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) + (io.dynamic_depth << l3_ratio), 0.U(BLOCK_OFFSET.W))
  val s1_pf_l3_decr_vaddr = Cat(region_to_block_addr(s1_region_tag, s1_region_bits) - (io.dynamic_depth << l3_ratio), 0.U(BLOCK_OFFSET.W))
  // TODO: remove this
  val s1_can_send_pf = Mux(s1_update, !((array(s1_index).bit_vec & UIntToOH(s1_region_bits)).orR), true.B)
  s0_can_accept := !(s1_valid && (region_hash_tag(s1_region_tag) === region_hash_tag(s0_region_tag)))

  when(s1_alloc) {
    // alloc a new entry
    array(s1_index).alloc(
      alloc_tag = s1_region_tag,
      alloc_bit_vec = UIntToOH(s1_region_bits),
      alloc_active = s1_plus_one_hit || s1_minus_one_hit,
      alloc_decr_mode = RegEnable(s0_plus_one_hit, s0_valid))

  }.elsewhen(s1_update) {
    // update an existing entry
    assert(array(s1_index).cnt =/= 0.U || array(s1_index).tag === s1_index, "entry should have been allocated before")
    array(s1_index).update(
      update_bit_vec = UIntToOH(s1_region_bits),
      update_active = s1_plus_one_hit || s1_minus_one_hit)
  }

  XSPerfAccumulate("s1_alloc", s1_alloc)
  XSPerfAccumulate("s1_update", s1_update)
  XSPerfAccumulate("s1_active_plus_one_hit", s1_valid && s1_plus_one_hit)
  XSPerfAccumulate("s1_active_minus_one_hit", s1_valid && s1_minus_one_hit)

  // s2: trigger prefetch if hit active bit vector, compute meta of prefetch req
  val s2_valid = RegNext(s1_valid)
  val s2_index = RegEnable(s1_index, s1_valid)
  val s2_region_bits = RegEnable(s1_region_bits, s1_valid)
  val s2_region_tag = RegEnable(s1_region_tag, s1_valid)
  val s2_pf_l1_incr_vaddr = RegEnable(s1_pf_l1_incr_vaddr, s1_valid)
  val s2_pf_l1_decr_vaddr = RegEnable(s1_pf_l1_decr_vaddr, s1_valid)
  val s2_pf_l2_incr_vaddr = RegEnable(s1_pf_l2_incr_vaddr, s1_valid)
  val s2_pf_l2_decr_vaddr = RegEnable(s1_pf_l2_decr_vaddr, s1_valid)
  val s2_pf_l3_incr_vaddr = RegEnable(s1_pf_l3_incr_vaddr, s1_valid)
  val s2_pf_l3_decr_vaddr = RegEnable(s1_pf_l3_decr_vaddr, s1_valid)
  val s2_can_send_pf = RegEnable(s1_can_send_pf, s1_valid)
  val s2_active = array(s2_index).active
  val s2_decr_mode = array(s2_index).decr_mode
  val s2_l1_vaddr = Mux(s2_decr_mode, s2_pf_l1_decr_vaddr, s2_pf_l1_incr_vaddr)
  val s2_l2_vaddr = Mux(s2_decr_mode, s2_pf_l2_decr_vaddr, s2_pf_l2_incr_vaddr)
  val s2_l3_vaddr = Mux(s2_decr_mode, s2_pf_l3_decr_vaddr, s2_pf_l3_incr_vaddr)
  val s2_will_send_pf = s2_valid && s2_active && s2_can_send_pf
  val s2_pf_req_valid = s2_will_send_pf && io.enable
  val s2_pf_l1_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l1_vaddr,
    width = WIDTH_CACHE_BLOCKS,
    decr_mode = s2_decr_mode,
    sink = SINK_L1,
    source = L1_HW_PREFETCH_STREAM)
  val s2_pf_l2_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l2_vaddr,
    width = L2_WIDTH_CACHE_BLOCKS,
    decr_mode = s2_decr_mode,
    sink = SINK_L2,
    source = L1_HW_PREFETCH_STREAM)
  val s2_pf_l3_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l3_vaddr,
    width = L3_WIDTH_CACHE_BLOCKS,
    decr_mode = s2_decr_mode,
    sink = SINK_L3,
    source = L1_HW_PREFETCH_STREAM)

  XSPerfAccumulate("s2_valid", s2_valid)
  XSPerfAccumulate("s2_will_not_send_pf", s2_valid && !s2_will_send_pf)
  XSPerfAccumulate("s2_will_send_decr_pf", s2_valid && s2_will_send_pf && s2_decr_mode)
  XSPerfAccumulate("s2_will_send_incr_pf", s2_valid && s2_will_send_pf && !s2_decr_mode)
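
  // Issue staging note: a single training hit can produce up to three requests (L1/L2/L3),
  // but there is only one prefetch_req port, so the requests are staggered below: the L1
  // request is sent at s3, the L2 request at s4, and the (Constantin-gated) L3 request at
  // s5, with the younger L1 request winning the priority mux when stages collide.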

  // s3: send the l1 prefetch req out
  val s3_pf_l1_valid = RegNext(s2_pf_req_valid)
  val s3_pf_l1_bits = RegEnable(s2_pf_l1_req_bits, s2_pf_req_valid)
  val s3_pf_l2_valid = RegNext(s2_pf_req_valid)
  val s3_pf_l2_bits = RegEnable(s2_pf_l2_req_bits, s2_pf_req_valid)
  val s3_pf_l3_bits = RegEnable(s2_pf_l3_req_bits, s2_pf_req_valid)

  XSPerfAccumulate("s3_pf_sent", s3_pf_l1_valid)

  // s4: send the l2 prefetch req out
  val s4_pf_l2_valid = RegNext(s3_pf_l2_valid)
  val s4_pf_l2_bits = RegEnable(s3_pf_l2_bits, s3_pf_l2_valid)
  val s4_pf_l3_bits = RegEnable(s3_pf_l3_bits, s3_pf_l2_valid)

  val enable_l3_pf = WireInit(Constantin.createRecord("enableL3StreamPrefetch" + p(XSCoreParamsKey).HartId.toString, initValue = 0.U)) =/= 0.U
  // s5: send the l3 prefetch req out
  val s5_pf_l3_valid = RegNext(s4_pf_l2_valid) && enable_l3_pf
  val s5_pf_l3_bits = RegEnable(s4_pf_l3_bits, s4_pf_l2_valid)

  io.prefetch_req.valid := s3_pf_l1_valid || s4_pf_l2_valid || s5_pf_l3_valid
  io.prefetch_req.bits := Mux(s3_pf_l1_valid, s3_pf_l1_bits, Mux(s4_pf_l2_valid, s4_pf_l2_bits, s5_pf_l3_bits))

  XSPerfAccumulate("s4_pf_sent", !s3_pf_l1_valid && s4_pf_l2_valid)
  XSPerfAccumulate("s4_pf_blocked", s3_pf_l1_valid && s4_pf_l2_valid)
  XSPerfAccumulate("pf_sent", io.prefetch_req.valid)

  // Stride lookup starts here
  // S0: Stride sends req
  val s0_lookup_valid = io.stream_lookup_req.valid
  val s0_lookup_vaddr = io.stream_lookup_req.bits.vaddr
  val s0_lookup_tag = get_region_tag(s0_lookup_vaddr)
  // S1: match
  val s1_lookup_valid = RegNext(s0_lookup_valid)
  val s1_lookup_tag = RegEnable(s0_lookup_tag, s0_lookup_valid)
  val s1_lookup_tag_match_vec = array.map(_.tag_match(s1_lookup_tag))
  val s1_lookup_hit = VecInit(s1_lookup_tag_match_vec).asUInt.orR
  val s1_lookup_index = OHToUInt(VecInit(s1_lookup_tag_match_vec))
  // S2: read active out
  val s2_lookup_valid = RegNext(s1_lookup_valid)
  val s2_lookup_hit = RegEnable(s1_lookup_hit, s1_lookup_valid)
  val s2_lookup_index = RegEnable(s1_lookup_index, s1_lookup_valid)
  val s2_lookup_active = array(s2_lookup_index).active
  // S3: send back to Stride
  val s3_lookup_valid = RegNext(s2_lookup_valid)
  val s3_lookup_hit = RegEnable(s2_lookup_hit, s2_lookup_valid)
  val s3_lookup_active = RegEnable(s2_lookup_active, s2_lookup_valid)
  io.stream_lookup_resp := s3_lookup_valid && s3_lookup_hit && s3_lookup_active

  // reset meta to avoid multi-hit problem
  for(i <- 0 until BIT_VEC_ARRAY_SIZE) {
    when(reset.asBool || RegNext(io.flush)) {
      array(i).reset(i)
    }
  }

  XSPerfHistogram("bit_vector_active", PopCount(VecInit(array.map(_.active)).asUInt), true.B, 0, BIT_VEC_ARRAY_SIZE, 1)
  XSPerfHistogram("bit_vector_decr_mode", PopCount(VecInit(array.map(_.decr_mode)).asUInt), true.B, 0, BIT_VEC_ARRAY_SIZE, 1)
  XSPerfAccumulate("hash_conflict", s0_valid && s2_valid && (s0_region_tag =/= s2_region_tag) && (region_hash_tag(s0_region_tag) === region_hash_tag(s2_region_tag)))
}
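
// A minimal wiring sketch (hypothetical test harness, not the in-tree integration;
// train_source, stride_lookup and pf_sink are placeholder signals):
//   val stream_array = Module(new StreamBitVectorArray)
//   stream_array.io.enable            := true.B
//   stream_array.io.flush             := false.B
//   stream_array.io.dynamic_depth     := DEPTH_CACHE_BLOCKS.U
//   stream_array.io.train_req         <> train_source       // DecoupledIO(PrefetchReqBundle) from the train filter
//   stream_array.io.stream_lookup_req := stride_lookup      // ValidIO(PrefetchReqBundle) from the stride prefetcher
//   pf_sink                           := stream_array.io.prefetch_req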