xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/VFPU.scala (revision 822120df13d14b93718f77e176868bc9ef203df2)
1/****************************************************************************************
2  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3  * Copyright (c) 2020-2021 Peng Cheng Laboratory
4  *
5  * XiangShan is licensed under Mulan PSL v2.
6  * You can use this software according to the terms and conditions of the Mulan PSL v2.
7  * You may obtain a copy of Mulan PSL v2 at:
8  *          http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  *
14  * See the Mulan PSL v2 for more details.
15  ****************************************************************************************
16  */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3.{Mux, _}
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.{VectorFloatAdder,VectorFloatFMA,VectorFloatDivider}
27import yunsuan.VfpuType
28import xiangshan.{FuOpType, SrcType, XSBundle, XSCoreParamsKey, XSModule}
29import xiangshan.backend.fu.fpu.FPUSubModule
30
/**
  * Vector floating-point unit (VLEN-wide) built from three lane-parallel sub-units:
  *  - [[VfaluWrapper]]  : uops selected by `VfpuType.isVfalu` (e.g. fadd / fsub)
  *  - [[VfmaccWrapper]] : `VfpuType.fmacc`
  *  - [[VfdivWrapper]]  : `VfpuType.fdiv`
  *
  * Only one uop is in flight at a time, tracked by a three-state FSM:
  *  - s_idle    : accept a uop and hand it to its sub-unit in the same cycle.
  *  - s_compute : wait for the sub-unit's out.valid and forward its result combinationally.
  *  - s_finish  : the result arrived while io.out.ready was low; present the registered
  *                copy (dataReg / fflagsReg) until the consumer takes it.
  *
  * The accepted uop and an element mask are registered (s0_uopReg / s0_maskReg) and fed to the
  * sub-units through `ready_out`, which they use to fold per-element fflags.
  */
class VFPU(implicit p: Parameters) extends FPUSubModule(p(XSCoreParamsKey).VLEN){
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VfpuType.dummy, "VFPU OpType not supported")
  XSError(io.in.valid && (io.in.bits.uop.ctrl.vconfig.vtype.vsew === 0.U), "8 bits not supported in VFPU")
  override val dataModule = null // Only use IO, not dataModule

  // rename signal
  val in = io.in.bits
  val ctrl = io.in.bits.uop.ctrl
  val vtype = ctrl.vconfig.vtype
  val src1Type = io.in.bits.uop.ctrl.srcType

  // result/fflags of the finishing sub-unit, and registered copies kept for s_finish
  val fflagsReg = RegInit(0.U(5.W))
  val fflagsWire = WireInit(0.U(5.W))
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vfalu = Module(new VfaluWrapper()(p))
  val vfmacc = Module(new VfmaccWrapper()(p))
  val vfdiv = Module(new VfdivWrapper()(p))
  val outValid = vfalu.io.out.valid || vfmacc.io.out.valid || vfdiv.io.out.valid
  val outFire = vfalu.io.out.fire() || vfmacc.io.out.fire() || vfdiv.io.out.fire()

  // which sub-unit the incoming uop is dispatched to
  val selAlu = VfpuType.isVfalu(ctrl.fuOpType)
  val selMacc = ctrl.fuOpType === VfpuType.fmacc
  val selDiv = ctrl.fuOpType === VfpuType.fdiv
  // The targeted sub-unit must accept the uop in the very cycle the VFPU does; otherwise the
  // FSM enters s_compute for a computation that never started and waits forever
  // (VfaluWrapper / VfmaccWrapper drop in.ready while io.out.ready is low).
  val unitInReady = (selAlu && vfalu.io.in.ready) ||
                    (selMacc && vfmacc.io.in.ready) ||
                    (selDiv && vfdiv.io.in.ready)

  // register the accepted uop; the mask is currently all ones (every element active)
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val s0_maskReg = Reg(UInt(8.W))
  val inHs = io.in.fire()
  when(inHs && state===s_idle){
    s0_uopReg := io.in.bits.uop
    s0_maskReg := Fill(8, 1.U(1.W))
  }

  // fsm
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      // result consumed immediately -> idle; result produced but stalled -> hold it in s_finish
      state := Mux(outValid, Mux(outFire, s_idle, s_finish), s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }
  fflagsReg := Mux(outValid, fflagsWire, fflagsReg)
  dataReg := Mux(outValid, dataWire, dataReg)

  // connect the input port of vfalu (rm presumably comes from FPUSubModule)
  vfalu.io.in.bits.src <> in.src
  vfalu.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfalu.io.in.bits.round_mode := rm
  vfalu.io.in.bits.fp_format := vtype.vsew(1,0)
  vfalu.io.in.bits.opb_widening := false.B // TODO
  vfalu.io.in.bits.res_widening := false.B // TODO
  vfalu.io.in.bits.op_code := ctrl.fuOpType
  vfalu.io.ready_out.s0_mask := s0_maskReg
  vfalu.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfalu.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfmacc
  vfmacc.io.in.bits.src <> in.src
  vfmacc.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfmacc.io.in.bits.round_mode := rm
  vfmacc.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfmacc.io.in.bits.opb_widening := DontCare // TODO
  vfmacc.io.in.bits.res_widening := false.B // TODO
  vfmacc.io.in.bits.op_code := DontCare
  vfmacc.io.ready_out.s0_mask := s0_maskReg
  vfmacc.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfmacc.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfdiv
  vfdiv.io.in.bits.src <> in.src
  vfdiv.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfdiv.io.in.bits.round_mode := rm
  vfdiv.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfdiv.io.in.bits.opb_widening := DontCare // TODO
  vfdiv.io.in.bits.res_widening := DontCare // TODO
  vfdiv.io.in.bits.op_code := DontCare
  vfdiv.io.ready_out.s0_mask := s0_maskReg
  vfdiv.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfdiv.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the output port: pick the finishing sub-unit by the registered uop's opcode
  fflagsWire := LookupTree(s0_uopReg.ctrl.fuOpType, List(
    VfpuType.fadd  -> vfalu.io.out.bits.fflags,
    VfpuType.fsub  -> vfalu.io.out.bits.fflags,
    VfpuType.fmacc -> vfmacc.io.out.bits.fflags,
    VfpuType.fdiv  -> vfdiv.io.out.bits.fflags,
  ))
  dataWire := LookupTree(s0_uopReg.ctrl.fuOpType, List(
    VfpuType.fadd -> vfalu.io.out.bits.result,
    VfpuType.fsub -> vfalu.io.out.bits.result,
    VfpuType.fmacc -> vfmacc.io.out.bits.result,
    VfpuType.fdiv -> vfdiv.io.out.bits.result,
  ))
  // Forward the live result whenever io.out.valid is driven from s_compute (not only on fire),
  // so the bits are correct even while the consumer stalls; s_finish uses the registered copy.
  val forwardLive = state === s_compute && outValid
  fflags := Mux(forwardLive, fflagsWire, fflagsReg)
  io.out.bits.data := Mux(forwardLive, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg

  // valid/ready
  vfalu.io.in.valid := io.in.valid && selAlu && state === s_idle
  vfmacc.io.in.valid := io.in.valid && selMacc && state === s_idle
  vfdiv.io.in.valid := io.in.valid && selDiv && state === s_idle
  io.out.valid := state === s_compute && outValid || state === s_finish
  vfalu.io.out.ready := io.out.ready
  vfmacc.io.out.ready := io.out.ready
  vfdiv.io.out.ready := io.out.ready
  io.in.ready := state === s_idle && unitInReady
}
142
/**
  * Shared IO of the VFPU sub-unit wrappers ([[VfaluWrapper]], [[VfmaccWrapper]], [[VfdivWrapper]]).
  *
  *  - `in`        : operands and control, handshaked with valid/ready. `Output(...)` on the payload
  *                  coerces every field to one direction, so no per-field direction is needed.
  *  - `ready_out` : per-uop side-band inputs (element mask, SEW, vl) used to fold per-element fflags.
  *  - `out`       : the VLEN-wide result and the accumulated fflags.
  */
class VFPUWraaperBundle (implicit p: Parameters)  extends XSBundle{
  val in = Flipped(DecoupledIO(Output(new Bundle {
    val src = Vec(4, UInt(VLEN.W))
    val srcType = Vec(4, SrcType())

    val round_mode = UInt(3.W)
    val fp_format = UInt(2.W) // vsew encoding
    val opb_widening = Bool()
    val res_widening = Bool()
    val op_code = FuOpType()
  })))

  val ready_out = Input(new Bundle {
    val s0_mask = UInt((VLEN / 16).W)
    val s0_sew = UInt(2.W)
    val s0_vl = UInt(8.W)
  })

  val out = DecoupledIO(Output(new Bundle {
    // sized from VLEN (not a literal 128) so the concatenated lane results are never truncated
    val result = UInt(VLEN.W)
    val fflags = UInt(5.W)
  }))
}
166
/**
  * Lane-parallel vector FP divider: NumAdder (= VLEN / XLEN) VectorFloatDivider instances,
  * each working on one XLEN-bit slice of the operands.
  *
  * There is no local valid pipeline. The handshake is delegated to the dividers
  * (start_valid/start_ready on the way in, finish_valid/finish_ready on the way out), and the
  * wrapper reports ready/valid only when every lane does.
  */
class VfdivWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = List(5, 7, 12) // not referenced in this module
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // per-uop context supplied by the parent through ready_out
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // XLEN-bit slice `i` of a VLEN-wide operand
  private def lane(v: UInt, i: Int): UInt = v(AdderWidth * (i + 1) - 1, AdderWidth * i)

  val vfdiv = Seq.fill(NumAdder)(Module(new VectorFloatDivider()))
  // vp-typed sources pass through; any other source is broadcast by VecExtractor at fp_format
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))

  vfdiv.zipWithIndex.foreach { case (div, i) =>
    // opa <- src2, opb <- src1 (presumably dividend = vs2, divisor = vs1 as in RVV vfdiv; confirm).
    // Outside the handshake cycle the operands are zeroed and the format is forced to b11.
    div.io.opa_i := Mux(inHs, lane(src2, i), 0.U)
    div.io.opb_i := Mux(inHs, lane(src1, i), 0.U)
    div.io.fp_format_i := Mux(inHs, in.fp_format, 3.U(2.W))
    div.io.is_vec_i := true.B // If you can enter, it must be vector
    div.io.rm_i := in.round_mode
    // NOTE(review): start_valid is the raw in.valid, while operands are gated by the all-lane
    // handshake; if lanes ever disagree on start_ready, a ready lane could start on zeros.
    div.io.start_valid_i := io.in.valid
    div.io.finish_ready_i := io.out.ready
    div.io.flush_i := false.B  // TODO
  }

  // fflags accumulated over elements [0, vl); zero when vl == 0
  private def uptoVl(acc: Vec[UInt]): UInt = Mux(s0_vl.orR, acc(s0_vl - 1.U), 0.U(5.W))

  val s4_fflagsVec = VecInit(vfdiv.map(_.io.fflags_o)).asUInt()
  val s4_fflags16vl = fflagsGen(s0_mask, s4_fflagsVec, List.range(0, 8))
  val s4_fflags32vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 1, 4, 5))
  val s4_fflags64vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 4))
  val s4_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> uptoVl(s4_fflags16vl),
    "b10".U -> uptoVl(s4_fflags32vl),
    "b11".U -> uptoVl(s4_fflags64vl),
  ))
  out.fflags := s4_fflags

  val s4_result = VecInit(vfdiv.map(_.io.fpdiv_res_o)).asUInt()
  out.result := s4_result

  // ready/valid only when every lane agrees
  io.in.ready := VecInit(vfdiv.map(_.io.start_ready_o)).asUInt().andR()
  io.out.valid := VecInit(vfdiv.map(_.io.finish_valid_o)).asUInt().andR()
}
213
/**
  * Lane-parallel vector FMA: NumAdder (= VLEN / XLEN) VectorFloatFMA instances, each working on
  * one XLEN-bit slice of the three operands.
  *
  * A Latency-deep shift register of valid bits tracks the uop; it has no internal back-pressure.
  * A new uop is accepted only when that pipe is empty and io.out.ready is high. Result and fflags
  * are taken directly from the FMA outputs while the last pipe stage is valid.
  */
class VfmaccWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 3
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // valid shift register: stage 0 records the handshake, every later stage copies its predecessor
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.head := inHs
  for ((prev, next) <- validPipe.zip(validPipe.tail)) {
    next := prev
  }

  // per-uop context supplied by the parent through ready_out
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // XLEN-bit slice `i` of a VLEN-wide operand
  private def lane(v: UInt, i: Int): UInt = v(AdderWidth * (i + 1) - 1, AdderWidth * i)

  val vfmacc = Seq.fill(NumAdder)(Module(new VectorFloatFMA()))
  // vp-typed sources pass through; any other source is broadcast by VecExtractor at fp_format
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  val src3 = Mux(in.srcType(2) === SrcType.vp, in.src(2), VecExtractor(in.fp_format, in.src(2)))

  for ((fma, i) <- vfmacc.zipWithIndex) {
    // outside the handshake cycle the operands are zeroed and the format is forced to b11
    fma.io.fp_a := Mux(inHs, lane(src1, i), 0.U)
    fma.io.fp_b := Mux(inHs, lane(src2, i), 0.U)
    fma.io.fp_c := Mux(inHs, lane(src3, i), 0.U)
    fma.io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    fma.io.is_vec := true.B // If you can enter, it must be vector
    fma.io.round_mode := in.round_mode
    fma.io.res_widening := in.res_widening // TODO
  }

  // fflags accumulated over elements [0, vl); zero when vl == 0
  private def uptoVl(acc: Vec[UInt]): UInt = Mux(s0_vl.orR, acc(s0_vl - 1.U), 0.U(5.W))

  val s2_fflagsVec = VecInit(vfmacc.map(_.io.fflags)).asUInt()
  val s2_fflags16vl = fflagsGen(s0_mask, s2_fflagsVec, List.range(0, 8))
  val s2_fflags32vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 1, 4, 5))
  val s2_fflags64vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 4))
  val s2_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> uptoVl(s2_fflags16vl),
    "b10".U -> uptoVl(s2_fflags32vl),
    "b11".U -> uptoVl(s2_fflags64vl),
  ))
  out.fflags := s2_fflags

  val s2_result = VecInit(vfmacc.map(_.io.fp_result)).asUInt()
  out.result := s2_result

  // busy while any pipe stage holds a uop
  val pipeBusy = validPipe.reduce(_ || _)
  io.in.ready := !pipeBusy && io.out.ready
  io.out.valid := validPipe.last
}
267
/**
  * Lane-parallel vector FP adder/ALU: NumAdder (= VLEN / XLEN) VectorFloatAdder instances, each
  * working on one XLEN-bit slice of the operands.
  *
  * A Latency-deep shift register of valid bits tracks the uop; it has no internal back-pressure.
  * The adder outputs (result selected by SEW, fflags folded up to vl) are registered when stage
  * Latency-2 is valid and presented when the last stage is valid. A new uop is accepted only when
  * the pipe is empty and io.out.ready is high.
  */
class VfaluWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 2
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // valid shift register: stage 0 records the handshake, every later stage copies its predecessor
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.head := inHs
  for ((prev, next) <- validPipe.zip(validPipe.tail)) {
    next := prev
  }

  // per-uop context supplied by the parent through ready_out
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // XLEN-bit slice `i` of a VLEN-wide operand
  private def lane(v: UInt, i: Int): UInt = v(AdderWidth * (i + 1) - 1, AdderWidth * i)

  val vfalu = Seq.fill(NumAdder)(Module(new VectorFloatAdder()))
  // vp-typed sources pass through; any other source is broadcast by VecExtractor at fp_format
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))

  for ((adder, i) <- vfalu.zipWithIndex) {
    // outside the handshake cycle the operands are zeroed and the format is forced to b11
    adder.io.fp_a := Mux(inHs, lane(src1, i), 0.U)
    adder.io.fp_b := Mux(inHs, lane(src2, i), 0.U)
    adder.io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    adder.io.is_vec := true.B // If you can enter, it must be vector
    adder.io.round_mode := in.round_mode
    adder.io.opb_widening := in.opb_widening // TODO
    adder.io.res_widening := in.res_widening // TODO
    adder.io.op_code := in.op_code
  }

  // fflags accumulated over elements [0, vl); zero when vl == 0
  private def uptoVl(acc: Vec[UInt]): UInt = Mux(s0_vl.orR, acc(s0_vl - 1.U), 0.U(5.W))

  val s0_fflagsVec = VecInit(vfalu.map(_.io.fflags)).asUInt()
  val s0_fflags16vl = fflagsGen(s0_mask, s0_fflagsVec, List.range(0, 8))
  val s0_fflags32vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 1, 4, 5))
  val s0_fflags64vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 4))
  val s0_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> uptoVl(s0_fflags16vl),
    "b10".U -> uptoVl(s0_fflags32vl),
    "b11".U -> uptoVl(s0_fflags64vl),
  ))

  // the adder outputs are captured while the second-to-last pipe stage is valid
  val s0_capture = validPipe(Latency - 2)
  val s1_fflags = RegEnable(s0_fflags, s0_capture)
  out.fflags := s1_fflags

  val s0_result = LookupTree(s0_sew(1, 0), List(
    "b01".U -> VecInit(vfalu.map(_.io.fp_f16_result)).asUInt(),
    "b10".U -> VecInit(vfalu.map(_.io.fp_f32_result)).asUInt(),
    "b11".U -> VecInit(vfalu.map(_.io.fp_f64_result)).asUInt(),
  ))
  val s1_result = RegEnable(s0_result, s0_capture)
  out.result := s1_result

  // busy while any pipe stage holds a uop
  val pipeBusy = validPipe.reduce(_ || _)
  io.in.ready := !pipeBusy && io.out.ready
  io.out.valid := validPipe.last
}
329
/**
  * Builds running (prefix-OR) fflags over vector elements.
  *
  * `fflagsResult` is a packed vector of 5-bit flag groups; `idx` lists which 5-bit groups are real
  * elements for the current SEW. Each selected group is zeroed when its mask bit is clear. Entry k
  * of the returned Vec is the OR of the masked flags of elements 0..k, so indexing it with vl-1
  * gives the accumulated fflags of the first vl elements.
  */
object fflagsGen{
  def fflagsGen(vmask: UInt, fflagsResult: UInt, idx: List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    val num = idx.length
    // NOTE(review): `.reverse` pairs element k with mask bit (num-1-k), i.e. MSB-first, whereas RVV
    // masks are LSB-first. Harmless while the VFPU drives an all-ones mask; confirm before using real masks.
    val elemMask = vmask(num - 1, 0).asBools().reverse
    val masked = idx.zip(elemMask).map { case (id, active) =>
      Mux(active, fflagsResult(id * 5 + 4, id * 5), 0.U(5.W))
    }
    // prefix OR: [f0, f0|f1, f0|f1|f2, ...]
    VecInit(masked.tail.scanLeft(masked.head)(_ | _))
  }

  def apply(vmask: UInt, fflagsResult: UInt, idx: List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    fflagsGen(vmask, fflagsResult, idx)
  }
}
350
/**
  * Broadcasts a scalar operand across a vector register.
  *
  * `sew` uses the vsew encoding (b00 = 8, b01 = 16, b10 = 32, b11 = 64 bits). The low SEW bits of
  * `xf` are replicated `vlen / SEW` times to produce a `vlen`-bit value. `vlen` defaults to 128,
  * the width this object previously hard-coded, so existing callers are unaffected.
  */
object VecExtractor{
  def xf2v_sew(sew: UInt, xf: UInt, vlen: Int = 128): UInt = {
    require(vlen > 0 && vlen % 64 == 0, s"VecExtractor: vlen ($vlen) must be a positive multiple of 64")
    // replicate the low `width` bits of xf to fill vlen bits
    def splat(width: Int): UInt = VecInit(Seq.fill(vlen / width)(xf(width - 1, 0))).asUInt
    LookupTree(sew(1, 0), List(
      "b00".U -> splat(8),
      "b01".U -> splat(16),
      "b10".U -> splat(32),
      "b11".U -> splat(64),
    ))
  }

  def apply(sew: UInt, xf: UInt, vlen: Int = 128): UInt = {
    xf2v_sew(sew, xf, vlen)
  }
}