/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  * Copyright (c) 2020-2021 Peng Cheng Laboratory
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.mem

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utils._
import utility._
import xiangshan._
import xiangshan.backend.rob.RobPtr
import xiangshan.backend.Bundles._
import xiangshan.mem._
import xiangshan.backend.fu.FuType
import freechips.rocketchip.diplomacy.BufferParams
import xiangshan.cache.mmu._
import xiangshan.cache._
import xiangshan.cache.wpu.ReplayCarry
import xiangshan.backend.fu.util.SdtrigExt
import xiangshan.ExceptionNO._
import xiangshan.backend.fu.vector.Bundles.VConfig

class VSegmentBundle(implicit p: Parameters) extends VLSUBundle
{
  val vaddr            = UInt(VAddrBits.W)
  val uop              = new DynInst
  val paddr            = UInt(PAddrBits.W)
  val mask             = UInt(VLEN.W)
  val valid            = Bool()
  val alignedType      = UInt(alignTypeBits.W)
  val vl               = UInt(elemIdxBits.W)
  val vlmaxInVd        = UInt(elemIdxBits.W)
  val vlmaxMaskInVd    = UInt(elemIdxBits.W)
  // for exception
  val vstart           = UInt(elemIdxBits.W)
  val exceptionvaddr   = UInt(VAddrBits.W)
  val exception_va     = Bool()
  val exception_pa     = Bool()
}

class VSegmentUnit (implicit p: Parameters) extends VLSUModule
  with HasDCacheParameters
  with MemoryOpConstants
  with SdtrigExt
  with HasLoadHelper
{
  val io               = IO(new VSegmentUnitIO)

  val maxSize          = VSegmentBufferSize

  class VSegUPtr(implicit p: Parameters) extends CircularQueuePtr[VSegUPtr](maxSize){
  }

  object VSegUPtr {
    def apply(f: Bool, v: UInt)(implicit p: Parameters): VSegUPtr = {
      val ptr           = Wire(new VSegUPtr)
      ptr.flag         := f
      ptr.value        := v
      ptr
    }
  }
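
  // Comment-only usage sketch (not elaborated hardware), assuming the
  // CircularQueuePtr helpers from the utility package already imported above
  // (the flag bit flips on wrap-around, which keeps the comparisons correct):
  //   val ptr  = VSegUPtr(f = false.B, v = 0.U)
  //   val next = ptr + 1.U              // wraps modulo maxSize, toggling ptr.flag
  //   isAfter(next, ptr)                // true.B after one enqueue
  //   distanceBetween(next, ptr)        // 1.U occupied entry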

  // buffer uop
  val instMicroOp       = Reg(new VSegmentBundle)
  val data              = Reg(Vec(maxSize, UInt(VLEN.W)))
  val uopIdx            = Reg(Vec(maxSize, UopIdx()))
  val stride            = Reg(Vec(maxSize, UInt(VLEN.W)))
  val allocated         = RegInit(VecInit(Seq.fill(maxSize)(false.B)))
  val enqPtr            = RegInit(0.U.asTypeOf(new VSegUPtr))
  val deqPtr            = RegInit(0.U.asTypeOf(new VSegUPtr))
  val stridePtr         = WireInit(0.U.asTypeOf(new VSegUPtr)) // selects the stride/index register

  val segmentIdx        = RegInit(0.U(elemIdxBits.W))
  val fieldIdx          = RegInit(0.U(fieldBits.W))
  val segmentOffset     = RegInit(0.U(VAddrBits.W))
  val splitPtr          = RegInit(0.U.asTypeOf(new VSegUPtr)) // selects the load/store data register
  val splitPtrNext      = WireInit(0.U.asTypeOf(new VSegUPtr))

  val exception_va      = WireInit(false.B)
  val exception_pa      = WireInit(false.B)

  val maxSegIdx         = instMicroOp.vl
  val maxNfields        = instMicroOp.uop.vpu.nf
  XSError(segmentIdx > maxSegIdx, s"segmentIdx > vl, something went wrong!\n")
  XSError(fieldIdx > maxNfields, s"fieldIdx > nfields, something went wrong!\n")

  // Segment instruction's FSM
  val s_idle :: s_flush_sbuffer_req :: s_wait_flush_sbuffer_resp :: s_tlb_req :: s_wait_tlb_resp :: s_pm :: s_cache_req :: s_cache_resp :: s_latch_and_merge_data :: s_finish :: Nil = Enum(10)
  val state             = RegInit(s_idle)
  val stateNext         = WireInit(s_idle)
  val sbufferEmpty      = io.flush_sbuffer.empty

  /**
   * state update
   */
  state  := stateNext

  /**
   * state transfer
   */
  when(state === s_idle){
    stateNext := Mux(isAfter(enqPtr, deqPtr), s_flush_sbuffer_req, s_idle)
  }.elsewhen(state === s_flush_sbuffer_req){
    stateNext := Mux(sbufferEmpty, s_tlb_req, s_wait_flush_sbuffer_resp) // if sbuffer is empty, go to query tlb

  }.elsewhen(state === s_wait_flush_sbuffer_resp){
    stateNext := Mux(sbufferEmpty, s_tlb_req, s_wait_flush_sbuffer_resp)

  }.elsewhen(state === s_tlb_req){
    stateNext := s_wait_tlb_resp

  }.elsewhen(state === s_wait_tlb_resp){
    stateNext := Mux(!io.dtlb.resp.bits.miss && io.dtlb.resp.fire, s_pm, s_tlb_req)

  }.elsewhen(state === s_pm){
    stateNext := Mux(exception_pa || exception_va, s_finish, s_cache_req)

  }.elsewhen(state === s_cache_req){
    stateNext := Mux(io.dcache.req.fire, s_cache_resp, s_cache_req)

  }.elsewhen(state === s_cache_resp){
    when(io.dcache.resp.fire) { // wait for the response to the outstanding request
      when(io.dcache.resp.bits.miss) {
        stateNext := s_cache_req
      }.otherwise {
        stateNext := s_latch_and_merge_data
      }
    }.otherwise{
      stateNext := s_cache_resp
    }

  }.elsewhen(state === s_latch_and_merge_data){
    when((segmentIdx === maxSegIdx) && (fieldIdx === maxNfields)){
      stateNext := s_finish // segment instruction finishes
    }.otherwise{
      stateNext := s_tlb_req // continue with the next element
    }

  }.elsewhen(state === s_finish){ // writeback uop
    stateNext := Mux(distanceBetween(enqPtr, deqPtr) === 0.U, s_idle, s_finish)

  }.otherwise{
    stateNext := s_idle
    XSError(true.B, s"Unknown state!\n")
  }
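
  // For reference, the element loop implied by the transitions above, one pass per
  // (segmentIdx, fieldIdx) pair, is roughly:
  //   s_idle -> s_flush_sbuffer_req [-> s_wait_flush_sbuffer_resp]
  //     -> { s_tlb_req -> s_wait_tlb_resp -> s_pm -> s_cache_req -> s_cache_resp
  //          -> s_latch_and_merge_data } repeated for every field of every segment
  //     -> s_finish (one writeback per buffered uop) -> s_idle
  // A TLB miss replays from s_tlb_req, a dcache miss replays from s_cache_req, and
  // an exception detected in s_pm jumps straight to s_finish.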

  /*************************************************************************
   *                            enqueue logic
   *************************************************************************/
  io.in.ready                         := true.B
  val fuOpType                         = io.in.bits.uop.fuOpType
  val vtype                            = io.in.bits.uop.vpu.vtype
  val mop                              = fuOpType(6, 5)
  val instType                         = Cat(true.B, mop)
  val eew                              = io.in.bits.uop.vpu.veew
  val sew                              = vtype.vsew
  val lmul                             = vtype.vlmul
  val vl                               = instMicroOp.vl
  val vm                               = instMicroOp.uop.vpu.vm
  val vstart                           = instMicroOp.uop.vpu.vstart
  val srcMask                          = GenFlowMask(Mux(vm, Fill(VLEN, 1.U(1.W)), io.in.bits.src_mask), vstart, vl, true)
  // on the first uop we need to latch the microOp of the whole segment instruction
  when(io.in.fire && !instMicroOp.valid){
    val vlmaxInVd                      = GenVLMAX(Mux(lmul.asSInt > 0.S, 0.U, lmul), Mux(isIndexed(instType), sew(1, 0), eew(1, 0))) // number of elements in one vd
    instMicroOp.vaddr                 := io.in.bits.src_rs1(VAddrBits - 1, 0)
    instMicroOp.valid                 := true.B // set only by the first uop
    instMicroOp.alignedType           := Mux(isIndexed(instType), sew(1, 0), eew(1, 0))
    instMicroOp.uop                   := io.in.bits.uop
    instMicroOp.mask                  := srcMask
    instMicroOp.vstart                := 0.U
    instMicroOp.vlmaxInVd             := vlmaxInVd
    instMicroOp.vlmaxMaskInVd         := UIntToMask(vlmaxInVd, elemIdxBits) // for merge data
    instMicroOp.vl                    := io.in.bits.src_vl.asTypeOf(VConfig()).vl
    segmentOffset                     := 0.U
    fieldIdx                          := 0.U
    segmentIdx                        := 0.U // reset the element walk for the new instruction
  }
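
  // Worked example (comment only, assuming VLEN = 128): for a segment access with
  // 32-bit elements and lmul >= 1, the Mux above clamps lmul to 0, so GenVLMAX
  // returns vlmaxInVd = 128 / 32 = 4, i.e. four elements per vd; vlmaxMaskInVd is
  // the matching mask used later to reduce segmentIdx to an element index within vd.
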
  // latch data
  when(io.in.fire){
    data(enqPtr.value)                := io.in.bits.src_vs3
    stride(enqPtr.value)              := io.in.bits.src_stride
    uopIdx(enqPtr.value)              := io.in.bits.uop.vpu.vuopIdx
  }

  // update enqPtr, only 1 port
  when(io.in.fire){
    enqPtr                            := enqPtr + 1.U
  }

  /*************************************************************************
   *                            output logic
   *************************************************************************/
  // MicroOp
  val baseVaddr                       = instMicroOp.vaddr
  val alignedType                     = instMicroOp.alignedType
  val fuType                          = instMicroOp.uop.fuType
  val mask                            = instMicroOp.mask
  val exceptionVec                    = instMicroOp.uop.exceptionVec
  val issueEew                        = instMicroOp.uop.vpu.veew
  val issueLmul                       = instMicroOp.uop.vpu.vtype.vlmul
  val issueSew                        = instMicroOp.uop.vpu.vtype.vsew
  val issueEmul                       = EewLog2(issueEew) - issueSew + issueLmul
  val elemIdxInVd                     = segmentIdx & instMicroOp.vlmaxMaskInVd
  val issueInstType                   = Cat(true.B, instMicroOp.uop.fuOpType(6, 5)) // always a segment instruction
  val issueVLMAXLog2                  = GenVLMAXLog2(
                                                      Mux(issueLmul.asSInt > 0.S, 0.U, issueLmul),
                                                      Mux(isIndexed(issueInstType), issueSew(1, 0), issueEew(1, 0))
                                                    ) // log2 of max element number in one vd
  val issueVlMax                      = instMicroOp.vlmaxInVd // max element index in one vd
  val issueMaxIdxInIndex              = GenVLMAX(Mux(issueEmul.asSInt > 0.S, 0.U, issueEmul), issueEew) // max element number in one index register
  val issueMaxIdxInIndexMask          = UIntToMask(issueMaxIdxInIndex, elemIdxBits)
  val issueMaxIdxInIndexLog2          = GenVLMAXLog2(Mux(issueEmul.asSInt > 0.S, 0.U, issueEmul), issueEew)
  val issueIndexIdx                   = segmentIdx & issueMaxIdxInIndexMask

  val indexStride                     = IndexAddr( // index offset for indexed instructions
                                                    index = stride(stridePtr.value),
                                                    flow_inner_idx = issueIndexIdx,
                                                    eew = issueEew
                                                  )
  val realSegmentOffset               = Mux(isIndexed(issueInstType),
                                            indexStride,
                                            segmentOffset)
  val vaddr                           = baseVaddr + (fieldIdx << alignedType).asUInt + realSegmentOffset
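
  // Worked example (comment only): for a unit-stride vlseg3e32 (nf = 2, three
  // 4-byte fields, alignedType = 2), element (segmentIdx = i, fieldIdx = f) is
  // accessed at baseVaddr + (f << 2) + i * 12, because segmentOffset advances by
  // (nf + 1) << eew = 12 bytes per completed segment (see the segmentOffset update
  // below). Indexed forms take the per-segment offset from IndexAddr instead.
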
  /**
   * tlb req and tlb resp
   */

  // query DTLB IO Assign
  io.dtlb.req                         := DontCare
  io.dtlb.resp.ready                  := true.B
  io.dtlb.req.valid                   := state === s_tlb_req
  io.dtlb.req.bits.cmd                := Mux(FuType.isVLoad(fuType), TlbCmd.read, TlbCmd.write)
  io.dtlb.req.bits.vaddr              := vaddr
  io.dtlb.req.bits.size               := instMicroOp.alignedType(2,0)
  io.dtlb.req.bits.memidx.is_ld       := FuType.isVLoad(fuType)
  io.dtlb.req.bits.memidx.is_st       := FuType.isVStore(fuType)
  io.dtlb.req.bits.debug.robIdx       := instMicroOp.uop.robIdx
  io.dtlb.req.bits.no_translate       := false.B
  io.dtlb.req.bits.debug.pc           := instMicroOp.uop.pc
  io.dtlb.req.bits.debug.isFirstIssue := DontCare
  io.dtlb.req_kill                    := false.B

  // tlb resp
  when(io.dtlb.resp.fire && state === s_wait_tlb_resp){
      exceptionVec(storePageFault)    := io.dtlb.resp.bits.excp(0).pf.st
      exceptionVec(loadPageFault)     := io.dtlb.resp.bits.excp(0).pf.ld
      exceptionVec(storeAccessFault)  := io.dtlb.resp.bits.excp(0).af.st
      exceptionVec(loadAccessFault)   := io.dtlb.resp.bits.excp(0).af.ld
      when(!io.dtlb.resp.bits.miss){
        instMicroOp.paddr             := io.dtlb.resp.bits.paddr(0)
      }
  }
  // pmp
  // NOTE: only load/store exceptions are handled here; other exceptions must never be sent to this unit
  val pmp = WireInit(io.pmpResp)
  when(state === s_pm){
    exception_va := exceptionVec(storePageFault) || exceptionVec(loadPageFault) ||
    exceptionVec(storeAccessFault) || exceptionVec(loadAccessFault)
    exception_pa := pmp.st || pmp.ld

    instMicroOp.exception_pa       := exception_pa
    instMicroOp.exception_va       := exception_va
    // update storeAccessFault bit
    exceptionVec(loadAccessFault)  := exceptionVec(loadAccessFault) || pmp.ld
    exceptionVec(storeAccessFault) := exceptionVec(storeAccessFault) || pmp.st

    // latch only the faulting element's info; an unconditional update would
    // clobber instMicroOp.vl and terminate the element loop early
    when(exception_va || exception_pa) {
      instMicroOp.exceptionvaddr   := vaddr
      instMicroOp.vl               := segmentIdx // for exception
      instMicroOp.vstart           := segmentIdx // for exception
    }
  }
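
  // When an element faults, the guarded update above freezes vstart/vl at the
  // faulting segmentIdx and latches its address, which s_finish then reports
  // through io.uopwriteback.bits.uop.vpu and (eventually) io.exceptionAddr.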

  /**
   * flush sbuffer IO Assign
   */
  io.flush_sbuffer.valid           := !sbufferEmpty && (state === s_flush_sbuffer_req)


  /**
   * merge data for load
   */
  val cacheData = io.dcache.resp.bits.data
  val pickData  = rdataVecHelper(alignedType(1,0), cacheData)
  val mergedData = mergeDataWithElemIdx(
    oldData = data(splitPtr.value),
    newData = Seq(pickData),
    alignedType = alignedType(1,0),
    elemIdx = Seq(elemIdxInVd),
    valids = Seq(true.B)
  )
  when(state === s_latch_and_merge_data){
    data(splitPtr.value) := mergedData
  }
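
  // Illustrative (comment only, assuming VLEN = 128): with alignedType = 2 (32-bit
  // elements) and elemIdxInVd = 1, mergeDataWithElemIdx overwrites bits [63:32] of
  // the old vd image with pickData and leaves bits [31:0] and [127:64] untouched.
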
  /**
   * split data for store
   */
  val splitData = genVSData(
    data = data(splitPtr.value),
    elemIdx = elemIdxInVd,
    alignedType = alignedType
  )
  val flowData  = genVWdata(splitData, alignedType) // TODO: connect vstd, pass vector data
  val wmask     = genVWmask(vaddr, alignedType(1, 0)) & mask(segmentIdx)
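
  // Rough sketch (comment only): genVWmask aligns a (1 << alignedType)-byte mask to
  // vaddr's offset within the access word, e.g. a 4-byte element at offset 4 of an
  // 8-byte word gives b11110000; ANDing with mask(segmentIdx) zeroes the write mask
  // for elements disabled by v0 masking.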

  /**
   * dcache req
   */
  io.dcache.req                    := DontCare
  io.dcache.req.valid              := state === s_cache_req && FuType.isVLoad(fuType)
  io.dcache.req.bits.cmd           := Mux(FuType.isVLoad(fuType), MemoryOpConstants.M_XRD, MemoryOpConstants.M_PFW)
  io.dcache.req.bits.vaddr         := vaddr
  io.dcache.req.bits.amo_mask      := Mux(FuType.isVLoad(fuType), mask, wmask)
  io.dcache.req.bits.amo_data      := flowData
  io.dcache.req.bits.source        := Mux(FuType.isVLoad(fuType), LOAD_SOURCE.U, STORE_SOURCE.U)
  io.dcache.req.bits.id            := DontCare

  /**
   * update ptr
   */

  val splitPtrOffset = Mux(lmul.asSInt < 0.S, 1.U, (1.U << lmul).asUInt)
  splitPtrNext := PriorityMux(Seq(
    ((fieldIdx === maxNfields) && (elemIdxInVd === (issueVlMax - 1.U)))   -> (deqPtr +                     // segment finished and the next register in the group is needed
                                                                             (segmentIdx >> issueVLMAXLog2).asUInt),
    (fieldIdx === maxNfields)                                             -> deqPtr,                       // segment finished
    true.B                                                                -> (splitPtr + splitPtrOffset)   // next field
  ))
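
  // Worked example (comment only): with nf = 1 (two fields) and lmul = 1, each
  // field's register group is two entries wide (splitPtrOffset = 2), so within one
  // segment splitPtr visits deqPtr and deqPtr + 2; when the last field of a segment
  // retires, splitPtr returns to deqPtr, offset by segmentIdx >> issueVLMAXLog2
  // once a vd fills up and the next register of each group must be written.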

  // update splitPtr
  when(state === s_latch_and_merge_data){
    splitPtr := splitPtrNext
  }.elsewhen(io.in.fire && !instMicroOp.valid){
    splitPtr := deqPtr // initial splitPtr
  }

  // update stridePtr, only used by indexed instructions
  val strideOffset = Mux(isIndexed(issueInstType), segmentIdx >> issueMaxIdxInIndexLog2, 0.U)
  stridePtr       := deqPtr + strideOffset

  // update fieldIdx
  when(fieldIdx === maxNfields && state === s_latch_and_merge_data){
    fieldIdx := 0.U
  }.elsewhen(state === s_latch_and_merge_data){
    fieldIdx := fieldIdx + 1.U
  }
  // update segmentIdx: advance to the next segment once its last field is merged
  // (without this increment the FSM would loop on segment 0 forever)
  when(fieldIdx === maxNfields && state === s_latch_and_merge_data && segmentIdx =/= maxSegIdx){
    segmentIdx := segmentIdx + 1.U
  }
  // update segmentOffset
  when(fieldIdx === maxNfields && state === s_latch_and_merge_data){
    segmentOffset := segmentOffset + Mux(isUnitStride(issueInstType), (maxNfields +& 1.U) << issueEew, stride(stridePtr.value))
  }
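
  // Worked example (comment only, assuming veew encodes log2 of the element byte
  // width): a unit-stride vlseg3e16 (nf = 2, 2-byte elements) advances
  // segmentOffset by (2 + 1) << 1 = 6 bytes per completed segment, while strided
  // forms add the rs2 stride latched in stride(stridePtr.value) instead.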

  // update deqPtr
  when(io.uopwriteback.fire){
    deqPtr := deqPtr + 1.U
  }

  /*************************************************************************
   *                            dequeue logic
   *************************************************************************/
  when(stateNext === s_idle){
    instMicroOp.valid := false.B
  }
  io.uopwriteback.valid               := state === s_finish
  io.uopwriteback.bits.uop            := instMicroOp.uop
  io.uopwriteback.bits.mask.get       := instMicroOp.mask
  io.uopwriteback.bits.data           := data(deqPtr.value)
  io.uopwriteback.bits.vdIdx.get      := uopIdx(deqPtr.value)
  io.uopwriteback.bits.uop.vpu.vl     := instMicroOp.vl
  io.uopwriteback.bits.uop.vpu.vstart := instMicroOp.vstart
  io.uopwriteback.bits.debug          := DontCare
  io.uopwriteback.bits.vdIdxInField.get := DontCare

  // to RS
  io.feedback.valid                   := state === s_finish
  io.feedback.bits.hit                := true.B
  io.feedback.bits.robIdx             := instMicroOp.uop.robIdx
  io.feedback.bits.sourceType         := DontCare
  io.feedback.bits.flushState         := DontCare
  io.feedback.bits.dataInvalidSqIdx   := DontCare
  io.feedback.bits.uopIdx.get         := uopIdx(deqPtr.value)

  // exception
  io.exceptionAddr                    := DontCare // TODO: fix this when exception handling is added
}