xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/VFPU.scala (revision 40767ba326fd3d393ed4428288fdc009ea1cbc85)
1/****************************************************************************************
2  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3  * Copyright (c) 2020-2021 Peng Cheng Laboratory
4  *
5  * XiangShan is licensed under Mulan PSL v2.
6  * You can use this software according to the terms and conditions of the Mulan PSL v2.
7  * You may obtain a copy of Mulan PSL v2 at:
8  *          http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  *
14  * See the Mulan PSL v2 for more details.
15  ****************************************************************************************
16  */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3.{Mux, _}
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.{VectorFloatAdder,VectorFloatFMA,VectorFloatDivider}
27import yunsuan.VfpuType
28import xiangshan.{FuOpType, SrcType, XSBundle, XSCoreParamsKey, XSModule}
29import xiangshan.backend.fu.fpu.FPUSubModule
30
/**
  * Top-level vector floating-point unit.
  *
  * Each incoming uop is offered to one of three sub-units -- [[VfaluWrapper]], [[VfmaccWrapper]] or
  * [[VfdivWrapper]] -- chosen by comparing `fuOpType` with `VfpuType.isVfalu` / `isVfmacc` /
  * `isVfdiv`. Their results are merged back onto `io.out`.
  *
  * Only one uop is in flight at a time; a three-state FSM tracks it:
  *  - `s_idle`:    `io.in.ready` is high; an input handshake latches the uop and moves to `s_compute`.
  *  - `s_compute`: waits until some sub-unit raises `out.valid`. If that result is accepted in the
  *                 same cycle the FSM returns to `s_idle`, otherwise it moves to `s_finish`.
  *  - `s_finish`:  keeps presenting the latched result until `io.out` fires.
  */
class VFPU(implicit p: Parameters) extends FPUSubModule(p(XSCoreParamsKey).VLEN){
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VfpuType.dummy, "VFPU OpType not supported")
  XSError(io.in.valid && (io.in.bits.uop.ctrl.vconfig.vtype.vsew === 0.U), "8 bits not supported in VFPU")
  override val dataModule = null // Only use IO, not dataModule

// rename signal
  val in = io.in.bits
  val ctrl = io.in.bits.uop.ctrl
  val vtype = ctrl.vconfig.vtype
  val src1Type = io.in.bits.uop.ctrl.srcType
  val uopIdx = ctrl.uopIdx

// def some signal
  // fflagsReg / dataReg hold the most recent sub-unit result so it can still be presented
  // in s_finish; fflagsWire / dataWire carry the result selected in the current cycle.
  val fflagsReg = RegInit(0.U(5.W))
  val fflagsWire = WireInit(0.U(5.W))
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vfalu = Module(new VfaluWrapper()(p))
  val vfmacc = Module(new VfmaccWrapper()(p))
  val vfdiv = Module(new VfdivWrapper()(p))
  val outValid = vfalu.io.out.valid || vfmacc.io.out.valid || vfdiv.io.out.valid
  val outFire = vfalu.io.out.fire() || vfmacc.io.out.fire() || vfdiv.io.out.fire()

// reg input signal
  // The accepted uop is latched so that sew / vl / fuOpType stay stable while the sub-unit computes.
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val s0_maskReg = Reg(UInt(8.W))
  val inHs = io.in.fire()
  when(inHs && state===s_idle){
    s0_uopReg := io.in.bits.uop
    s0_maskReg := Fill(8, 1.U(1.W)) // all-ones mask: every element contributes its flags
  }

// fsm
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      // Result accepted immediately -> idle; result produced but not accepted -> finish.
      state := Mux(outValid, Mux(outFire, s_idle, s_finish),
                             s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }
  // Capture the selected result whenever any sub-unit reports a valid output.
  fflagsReg := Mux(outValid, fflagsWire, fflagsReg)
  dataReg := Mux(outValid, dataWire, dataReg)

// connect the input port of vfalu
  vfalu.io.in.bits.src <> in.src
  vfalu.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfalu.io.in.bits.round_mode := rm
  vfalu.io.in.bits.fp_format := vtype.vsew(1,0)
  vfalu.io.in.bits.uopIdx := uopIdx(0) //TODO
  vfalu.io.in.bits.opb_widening := false.B // TODO
  vfalu.io.in.bits.res_widening := false.B // TODO
  vfalu.io.in.bits.op_code := ctrl.fuOpType
  vfalu.io.ready_out.s0_mask := s0_maskReg
  vfalu.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfalu.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

//  connect the input port of vfmacc
  vfmacc.io.in.bits.src <> in.src
  vfmacc.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfmacc.io.in.bits.round_mode := rm
  vfmacc.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfmacc.io.in.bits.uopIdx := uopIdx(0) //TODO
  vfmacc.io.in.bits.opb_widening := DontCare // TODO
  vfmacc.io.in.bits.res_widening := false.B // TODO
  // VfmaccWrapper forwards op_code(3, 0) to every FMA lane, so it has to carry the real
  // operation (as on the vfalu path) rather than being left as DontCare.
  vfmacc.io.in.bits.op_code := ctrl.fuOpType
  vfmacc.io.ready_out.s0_mask := s0_maskReg
  vfmacc.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfmacc.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  //  connect the input port of vfdiv
  vfdiv.io.in.bits.src <> in.src
  vfdiv.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfdiv.io.in.bits.round_mode := rm
  vfdiv.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfdiv.io.in.bits.uopIdx := uopIdx(0) // TODO
  vfdiv.io.in.bits.opb_widening := DontCare // TODO
  vfdiv.io.in.bits.res_widening := DontCare // TODO
  vfdiv.io.in.bits.op_code := DontCare // VfdivWrapper does not read op_code
  vfdiv.io.ready_out.s0_mask := s0_maskReg
  vfdiv.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfdiv.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

// connect the output port
  // Result selection uses the latched uop, because io.in may already show a different uop
  // by the time the sub-unit finishes.
  fflagsWire := Mux1H(
    Seq(
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfalu ) -> vfalu.io.out.bits.fflags,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfmacc) -> vfmacc.io.out.bits.fflags,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfdiv ) -> vfdiv.io.out.bits.fflags
    )
  )
  // Bypass the register when the result is consumed in the cycle it is produced.
  fflags := Mux(state === s_compute && outFire, fflagsWire, fflagsReg)
  dataWire := Mux1H(
    Seq(
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfalu ) -> vfalu.io.out.bits.result,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfmacc) -> vfmacc.io.out.bits.result,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfdiv ) -> vfdiv.io.out.bits.result
    )
  )
  io.out.bits.data := Mux(state === s_compute && outFire, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg
  // valid/ready
  // A sub-unit only sees valid while idle, and only when the incoming fuOpType selects it.
  vfalu.io.in.valid  := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.isVfalu  && state === s_idle
  vfmacc.io.in.valid := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.isVfmacc && state === s_idle
  vfdiv.io.in.valid  := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.isVfdiv  && state === s_idle
  io.out.valid := state === s_compute && outValid || state === s_finish
  vfalu.io.out.ready := io.out.ready
  vfmacc.io.out.ready := io.out.ready
  vfdiv.io.out.ready := io.out.ready
  io.in.ready := state === s_idle
}
148
/**
  * IO bundle shared by the VFPU sub-unit wrappers ([[VfaluWrapper]], [[VfmaccWrapper]],
  * [[VfdivWrapper]]).
  *
  *  - `in`:        decoupled request carrying up to four VLEN-wide sources plus per-op control.
  *  - `ready_out`: side-band mask / sew / vl used by the wrappers when reducing exception flags.
  *  - `out`:       decoupled response with the packed result and a single 5-bit fflags value.
  */
class VFPUWraaperBundle (implicit p: Parameters)  extends XSBundle{
  val in = Flipped(DecoupledIO(Output(new Bundle {
    val src = Vec(4, Input(UInt(VLEN.W)))
    val srcType = Vec(4, SrcType())

    val round_mode = UInt(3.W)
    val fp_format = UInt(2.W) // vsew
    val uopIdx = Bool()
    val opb_widening = Bool()
    val res_widening = Bool()
    val op_code = FuOpType()
  })))

  val ready_out = Input(new Bundle {
    val s0_mask = UInt((VLEN / 16).W)
    val s0_sew = UInt(2.W)
    val s0_vl = UInt(8.W)
  })

  val out = DecoupledIO(Output(new Bundle {
    // Sized from VLEN like `src` and `s0_mask` instead of a literal 128. The wrappers pack
    // VLEN / XLEN lane results into this field (each lane presumably XLEN bits wide -- confirm
    // against the yunsuan lane modules).
    val result = UInt(VLEN.W)
    val fflags = UInt(5.W)
  }))
}
173
/**
  * Drives `VLEN / XLEN` [[VectorFloatDivider]] lanes from a single [[VFPUWraaperBundle]].
  *
  * Operands are sliced into XLEN-wide chunks per lane. The per-lane exception flags are packed,
  * reduced by `fflagsGen` for each element width, and the entry at index `vl - 1` is reported
  * (zero when `vl` is zero). The wrapper is ready / valid only when every lane is.
  */
class VfdivWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = List(5, 7, 12)
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfdiv = Seq.fill(NumAdder)(Module(new VectorFloatDivider()))

  // A source typed SrcType.vp is used as-is; any other source is replicated across the
  // vector by VecExtractor according to the element format.
  private def operand(n: Int): UInt =
    Mux(in.srcType(n) === SrcType.vp, in.src(n), VecExtractor(in.fp_format, in.src(n)))

  val src1 = operand(0)
  val src2 = operand(1)

  vfdiv.zipWithIndex.foreach { case (lane, i) =>
    val lo = AdderWidth * i
    val hi = AdderWidth * (i + 1) - 1

    // Operands and format are only presented on an input handshake.
    lane.io.opb_i := Mux(inHs, src1(hi, lo), 0.U)
    lane.io.opa_i := Mux(inHs, src2(hi, lo), 0.U)
    lane.io.fp_format_i := Mux(inHs, in.fp_format, 3.U(2.W))
    lane.io.rm_i := in.round_mode

    lane.io.is_vec_i := true.B // If you can enter, it must be vector
    lane.io.frs2_i    := in.src(0)(63,0) // f[rs2]
    lane.io.frs1_i    := in.src(1)(63,0) // f[rs1]
    lane.io.is_frs1_i := false.B // if true, vs2 / f[rs1]
    lane.io.is_frs2_i := false.B // if true, f[rs2] / vs1
    lane.io.is_sqrt_i := false.B // must false, not support sqrt now

    lane.io.start_valid_i := io.in.valid
    lane.io.finish_ready_i := io.out.ready
    lane.io.flush_i := false.B  // TODO
  }

  // Flag entry at index vl - 1 of a prefix-OR vector, or zero when vl is zero.
  private def flagsAtVl(prefix: Vec[UInt]): UInt =
    Mux(s0_vl.orR, prefix(s0_vl - 1.U), 0.U(5.W))

  val s4_fflagsVec = VecInit(vfdiv.map(_.io.fflags_o)).asUInt()
  val s4_fflags16vl = fflagsGen(s0_mask, s4_fflagsVec, (0 until 8).toList)
  val s4_fflags32vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 1, 4, 5))
  val s4_fflags64vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 4))
  val s4_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> flagsAtVl(s4_fflags16vl),
    "b10".U -> flagsAtVl(s4_fflags32vl),
    "b11".U -> flagsAtVl(s4_fflags64vl),
  ))
  out.fflags := s4_fflags

  val s4_result = VecInit(vfdiv.map(_.io.fpdiv_res_o)).asUInt()
  out.result := s4_result

  // All lanes must agree before the wrapper handshakes in either direction.
  io.in.ready := VecInit(vfdiv.map(_.io.start_ready_o)).asUInt().andR()
  io.out.valid := VecInit(vfdiv.map(_.io.finish_valid_o)).asUInt().andR()
}
225
/**
  * Drives `VLEN / XLEN` [[VectorFloatFMA]] lanes from a single [[VFPUWraaperBundle]].
  *
  * The wrapper always accepts input; `out.valid` is the input handshake delayed by `Latency`
  * register stages. Lane flags are reduced by `fflagsGen` per element width and the entry at
  * index `vl - 1` is reported (zero when `vl` is zero).
  */
class VfmaccWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 3
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // Shift register tracking the handshake through the pipeline stages.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.head := inHs
  validPipe.zip(validPipe.tail).foreach { case (earlier, later) => later := earlier }

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfmacc = Seq.fill(NumAdder)(Module(new VectorFloatFMA()))

  // A source typed SrcType.vp is used as-is; any other source is replicated across the
  // vector by VecExtractor according to the element format.
  private def operand(n: Int): UInt =
    Mux(in.srcType(n) === SrcType.vp, in.src(n), VecExtractor(in.fp_format, in.src(n)))

  val src1 = operand(0)
  val src2 = operand(1)
  val src3 = operand(2)

  private val halfWidth = AdderWidth / 2

  // XLEN-wide chunk of `src` belonging to lane i.
  private def laneSlice(src: UInt, i: Int): UInt =
    src(AdderWidth * (i + 1) - 1, AdderWidth * i)

  // Half-width chunks number (i + 2) and i of `src`, concatenated with (i + 2) on top.
  private def wideningSlice(src: UInt, i: Int): UInt =
    Cat(src(halfWidth * (i + 3) - 1, halfWidth * (i + 2)), src(halfWidth * (i + 1) - 1, halfWidth * i))

  vfmacc.zipWithIndex.foreach { case (lane, i) =>
    // Operands and format are only presented on an input handshake.
    lane.io.fp_a := Mux(inHs, laneSlice(src1, i), 0.U)
    lane.io.fp_b := Mux(inHs, laneSlice(src2, i), 0.U)
    lane.io.fp_c := Mux(inHs, laneSlice(src3, i), 0.U)
    lane.io.widen_b := Mux(inHs, wideningSlice(src1, i), 0.U)
    lane.io.widen_a := Mux(inHs, wideningSlice(src2, i), 0.U)
    lane.io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))

    lane.io.frs1 := in.src(0)(63,0)
    lane.io.is_frs1 := false.B // TODO: support vf inst
    lane.io.uop_idx := in.uopIdx // TODO
    lane.io.op_code := in.op_code(3,0)
    lane.io.is_vec := true.B // If you can enter, it must be vector
    lane.io.round_mode := in.round_mode
    lane.io.res_widening := in.res_widening // TODO
  }

  // output signal generation
  // Flag entry at index vl - 1 of a prefix-OR vector, or zero when vl is zero.
  private def flagsAtVl(prefix: Vec[UInt]): UInt =
    Mux(s0_vl.orR, prefix(s0_vl - 1.U), 0.U(5.W))

  val s2_fflagsVec = VecInit(vfmacc.map(_.io.fflags)).asUInt()
  val s2_fflags16vl = fflagsGen(s0_mask, s2_fflagsVec, (0 until 8).toList)
  val s2_fflags32vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 1, 4, 5))
  val s2_fflags64vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 4))
  val s2_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> flagsAtVl(s2_fflags16vl),
    "b10".U -> flagsAtVl(s2_fflags32vl),
    "b11".U -> flagsAtVl(s2_fflags64vl),
  ))
  out.fflags := s2_fflags

  val s2_result = VecInit(vfmacc.map(_.io.fp_result)).asUInt()
  out.result := s2_result

  io.in.ready := true.B
  io.out.valid := validPipe.last
}
285
/**
  * Drives `VLEN / XLEN` [[VectorFloatAdder]] lanes from a single [[VFPUWraaperBundle]].
  *
  * The wrapper always accepts input; `out.valid` is the input handshake delayed by `Latency`
  * register stages. Lane flags are reduced by `fflagsGen` per element width, the entry at index
  * `vl - 1` is chosen (zero when `vl` is zero), and both flags and result are registered once
  * before reaching `out`.
  */
class VfaluWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 2
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // reg input signal
  // Shift register tracking the handshake through the pipeline stages.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.head := inHs
  validPipe.zip(validPipe.tail).foreach { case (earlier, later) => later := earlier }

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // connect the input signal
  val vfalu = Seq.fill(NumAdder)(Module(new VectorFloatAdder()))

  // A source typed SrcType.vp is used as-is; any other source is replicated across the
  // vector by VecExtractor according to the element format.
  private def operand(n: Int): UInt =
    Mux(in.srcType(n) === SrcType.vp, in.src(n), VecExtractor(in.fp_format, in.src(n)))

  val src1 = operand(0)
  val src2 = operand(1)

  private val halfWidth = AdderWidth / 2

  // XLEN-wide chunk of `src` belonging to lane i.
  private def laneSlice(src: UInt, i: Int): UInt =
    src(AdderWidth * (i + 1) - 1, AdderWidth * i)

  // Half-width chunks number (i + 2) and i of `src`, concatenated with (i + 2) on top.
  private def wideningSlice(src: UInt, i: Int): UInt =
    Cat(src(halfWidth * (i + 3) - 1, halfWidth * (i + 2)), src(halfWidth * (i + 1) - 1, halfWidth * i))

  vfalu.zipWithIndex.foreach { case (lane, i) =>
    // Operands and format are only presented on an input handshake.
    // Note the crossing: src1 feeds the "b" ports and src2 feeds the "a" ports.
    lane.io.fp_b := Mux(inHs, laneSlice(src1, i), 0.U)
    lane.io.fp_a := Mux(inHs, laneSlice(src2, i), 0.U)
    lane.io.widen_b := Mux(inHs, wideningSlice(src1, i), 0.U)
    lane.io.widen_a := Mux(inHs, wideningSlice(src2, i), 0.U)
    lane.io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))

    lane.io.frs1 := in.src(0)(63, 0)
    lane.io.is_frs1 := false.B // TODO: support vf inst
    lane.io.mask := 0.U //TODO
    lane.io.uop_idx := in.uopIdx //TODO
    lane.io.is_vec := true.B // If you can enter, it must be vector
    lane.io.round_mode := in.round_mode
    lane.io.opb_widening := in.opb_widening // TODO
    lane.io.res_widening := in.res_widening // TODO
    lane.io.op_code := in.op_code(4,0)
  }

  // output signal generation
  // Flag entry at index vl - 1 of a prefix-OR vector, or zero when vl is zero.
  private def flagsAtVl(prefix: Vec[UInt]): UInt =
    Mux(s0_vl.orR, prefix(s0_vl - 1.U), 0.U(5.W))

  val s0_fflagsVec = VecInit(vfalu.map(_.io.fflags)).asUInt()
  val s0_fflags16vl = fflagsGen(s0_mask, s0_fflagsVec, (0 until 8).toList)
  val s0_fflags32vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 1, 4, 5))
  val s0_fflags64vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 4))
  val s0_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> flagsAtVl(s0_fflags16vl),
    "b10".U -> flagsAtVl(s0_fflags32vl),
    "b11".U -> flagsAtVl(s0_fflags64vl),
  ))
  // Flags and result are captured while the second-to-last valid stage is set.
  val s1_fflags = RegEnable(s0_fflags, validPipe(Latency-2))
  out.fflags := s1_fflags

  val s0_result = VecInit(vfalu.map(_.io.fp_result)).asUInt()
  val s1_result = RegEnable(s0_result, validPipe(Latency-2))
  out.result := s1_result

  io.in.ready := true.B
  io.out.valid := validPipe.last
}
349
/**
  * Builds a running OR of masked 5-bit exception-flag fields.
  *
  * `fflagsResult` is treated as a packed array of 5-bit fields. `idx` lists which fields to
  * visit, in order. Each visited field is zeroed unless its mask bit is set, and entry `k` of the
  * returned Vec is the OR of the first `k + 1` gated fields.
  */
object fflagsGen{
  def fflagsGen(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    val count = idx.length
    // The low `count` mask bits are consumed most-significant first, so the first visited
    // field is paired with bit (count - 1).
    val enables = vmask(count - 1, 0).asBools().reverse
    val gated = enables.zip(idx).map { case (enable, slot) =>
      val field = Wire(UInt(5.W))
      field := Mux(enable, fflagsResult(slot * 5 + 4, slot * 5 + 0), 0.U)
      field
    }
    val running = Wire(Vec(count, UInt(5.W)))
    gated.zipWithIndex.foreach { case (field, k) =>
      running(k) := (if (k == 0) field else running(k - 1) | field)
    }
    running
  }

  /** Shorthand so callers can write `fflagsGen(mask, flags, idx)` directly. */
  def apply(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    fflagsGen(vmask, fflagsResult, idx)
  }
}
370
/**
  * Replicates the low element of a scalar across a 128-bit vector.
  *
  * `sew(1, 0)` selects the element width: 8, 16, 32 or 64 bits for encodings 0 to 3. The low
  * element of `xf` is repeated 16, 8, 4 or 2 times respectively.
  */
object VecExtractor{
  def xf2v_sew(sew: UInt, xf:UInt): UInt = {
    // One (encoding -> replicated vector) entry per element width; every entry is 128 bits.
    val byWidth = List(
      "b00".U -> Fill(16, xf(7, 0)),
      "b01".U -> Fill(8, xf(15, 0)),
      "b10".U -> Fill(4, xf(31, 0)),
      "b11".U -> Fill(2, xf(63, 0)),
    )
    LookupTree(sew(1, 0), byWidth)
  }

  /** Shorthand for [[xf2v_sew]]. */
  def apply(sew: UInt, xf: UInt): UInt = xf2v_sew(sew, xf)
}
384}