xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/VFPU.scala (revision 757024a1f2cbf425bb3460f7409aa6a801e4f5d5)
1/****************************************************************************************
2  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3  * Copyright (c) 2020-2021 Peng Cheng Laboratory
4  *
5  * XiangShan is licensed under Mulan PSL v2.
6  * You can use this software according to the terms and conditions of the Mulan PSL v2.
7  * You may obtain a copy of Mulan PSL v2 at:
8  *          http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  *
14  * See the Mulan PSL v2 for more details.
15  ****************************************************************************************
16  */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3.{Mux, _}
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.VectorFloatAdder
27import yunsuan.VfpuType
28import xiangshan.{SrcType, XSCoreParamsKey, XSModule, FuOpType}
29import xiangshan.backend.fu.fpu.FPUSubModule
30
/** Vector floating-point unit.
  *
  * Thin adapter between the FPU sub-module interface and [[VfaluWrapper]].
  *  - It forwards the request operands, vector configuration and opcode to the wrapper.
  *  - It captures the uop on the input handshake. The captured uop tells the wrapper the SEW/vl of
  *    the in-flight operation and is returned alongside the result.
  *  - It passes the wrapper's result, fflags and ready/valid signals straight through.
  *
  * Masking is not implemented yet: every mask bit is driven to 1.
  */
class VFPU(implicit p: Parameters) extends FPUSubModule(p(XSCoreParamsKey).VLEN){
  override val dataModule = null // Only use IO, not dataModule

  // Short-hand views of the incoming request
  val in = io.in.bits
  val ctrl = in.uop.ctrl
  val vtype = ctrl.vconfig.vtype
  val src1Type = ctrl.srcType

  // Flag requests this unit does not support (unsupported opcode, 8-bit elements)
  XSError(io.in.valid && ctrl.fuOpType === VfpuType.dummy, "VFPU OpType not supported")
  XSError(io.in.valid && (vtype.vsew === 0.U), "8 bits not supported in VFPU")

  // Snapshot of the accepted request, updated only on the input handshake
  val inHs = io.in.fire()
  val s0_uopReg = RegEnable(in.uop, inHs)
  val s0_maskReg = RegEnable(Fill(8, 1.U(1.W)), inHs) // all-ones: masking not supported yet

  val vfalu = Module(new VfaluWrapper()(p))

  // Request side: operands, vector configuration and operation
  vfalu.io.in.valid := io.in.valid
  io.in.ready := vfalu.io.in.ready
  vfalu.io.in.bits.src <> in.src
  vfalu.io.in.bits.srcType <> ctrl.srcType
  vfalu.io.in.bits.vmask := Fill(8, 1.U(1.W))
  vfalu.io.in.bits.vl := ctrl.vconfig.vl
  vfalu.io.in.bits.round_mode := rm
  vfalu.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfalu.io.in.bits.opb_widening := false.B // TODO
  vfalu.io.in.bits.res_widening := false.B // TODO
  vfalu.io.in.bits.op_code := ctrl.fuOpType

  // Side-band description of the operation taken from the captured uop
  vfalu.io.ready_out.s0_mask := s0_maskReg
  vfalu.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfalu.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // Response side: result, exception flags and the uop it belongs to
  io.out.valid := vfalu.io.out.valid
  vfalu.io.out.ready := io.out.ready
  io.out.bits.uop := s0_uopReg
  io.out.bits.data := vfalu.io.out.bits.result
  fflags := vfalu.io.out.bits.fflags
}
76
/** Vector floating-point adder wrapper.
  *
  * Splits VLEN-wide operands across `NumAdder` XLEN-wide [[VectorFloatAdder]] instances. In
  * stage 0 it:
  *  - selects the result lanes for the element width;
  *  - accumulates fflags over the first `vl` elements.
  * It then registers both values so they are presented `Latency` cycles after the input handshake.
  *
  * `io.in.ready` is low while any pipeline stage is occupied, so at most one operation is in
  * flight. `io.ready_out` must describe that operation while it sits in stage 0, i.e. the cycle
  * after the handshake; that is when the stage-0 result and fflags are registered. The last stage
  * holds its result until `io.out.ready` is seen, so results are not lost under back-pressure.
  */
class VfaluWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 2
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new Bundle{
    // NOTE(review): in.vmask and in.vl are not read by this module; the stage-0 mask/vl used for
    // fflags come from io.ready_out instead.
    val in = Flipped(DecoupledIO(Output(new Bundle{
      val src = Vec(3, Input(UInt(VLEN.W)))
      val srcType = Vec(4, SrcType())
      val vmask = UInt((VLEN/16).W)
      val vl = UInt(8.W)

      val round_mode = UInt(3.W)
      val fp_format = UInt(2.W) // vsew
      val opb_widening  = Bool()
      val res_widening  = Bool()
      val op_code       = FuOpType()
    })))

    // Mask / SEW / vl of the operation currently in stage 0
    val ready_out = Input(new Bundle {
      val s0_mask = UInt((VLEN / 16).W)
      val s0_sew = UInt(2.W)
      val s0_vl = UInt(8.W)
    })

    val out = DecoupledIO(Output(new Bundle{
      // Sized like the operands (was hard-coded to 128); assumes the NumAdder lane results
      // concatenate to VLEN bits — TODO confirm for non-128 VLEN.
      val result = UInt(VLEN.W)
      val fflags = UInt(5.W)
    }))
  })

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // One valid bit per pipeline stage. Stage 0 is set by the input handshake; each later stage
  // copies the previous one.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.zipWithIndex.foreach {
    case (valid, idx) =>
      val advance = if (idx == 0) inHs else validPipe(idx - 1)
      // The last stage holds until the consumer accepts the result. This cannot collide with a
      // following operation: io.in.ready is low while any stage is valid, so nothing advances
      // into a held stage, and the s1_* data registers are only enabled by stage Latency-2.
      valid := (if (idx == Latency - 1) advance || (valid && !io.out.ready) else advance)
  }
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // Operand distribution. Sources whose srcType is not SrcType.vp are replicated across the
  // vector by VecExtractor.
  val vfalu = Seq.fill(NumAdder)(Module(new VectorFloatAdder()))
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  for (i <- 0 until NumAdder) {
    // Outside the handshake cycle the operands are zeroed and fp_format is forced to 3
    // (presumably to keep the adders quiet — confirm).
    vfalu(i).io.fp_a := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfalu(i).io.fp_b := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfalu(i).io.is_vec := true.B // If you can enter, it must be vector
    vfalu(i).io.round_mode := in.round_mode
    vfalu(i).io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    vfalu(i).io.opb_widening := in.opb_widening // TODO
    vfalu(i).io.res_widening := in.res_widening // TODO
    vfalu(i).io.op_code := in.op_code
  }

  // Output generation. The adder outputs are sampled in stage 0, one cycle after the handshake;
  // this presumes VectorFloatAdder has one cycle of internal latency — TODO confirm.
  // fflags: running OR over the lanes of each element width; entry vl-1 covers the first vl
  // elements.
  // NOTE(review): `s0_vl - 1.U` dynamically indexes a Vec of at most 8 entries; vl values larger
  // than the element count of one register are not handled here — confirm the caller bounds vl.
  val s0_fflagsVec = VecInit(vfalu.map(_.io.fflags)).asUInt()
  val s0_fflags16vl = fflagsGen(s0_mask, s0_fflagsVec, List.range(0, 8))
  val s0_fflags32vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 1, 4, 5))
  val s0_fflags64vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 4))
  val s0_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> Mux(s0_vl.orR, s0_fflags16vl(s0_vl - 1.U), 0.U(5.W)),
    "b10".U -> Mux(s0_vl.orR, s0_fflags32vl(s0_vl - 1.U), 0.U(5.W)),
    "b11".U -> Mux(s0_vl.orR, s0_fflags64vl(s0_vl - 1.U), 0.U(5.W)),
  ))
  val s1_fflags = RegEnable(s0_fflags, validPipe(Latency-2))
  out.fflags := s1_fflags

  val s0_result = LookupTree(s0_sew(1, 0), List(
    "b01".U -> VecInit(vfalu.map(_.io.fp_f16_result)).asUInt(),
    "b10".U -> VecInit(vfalu.map(_.io.fp_f32_result)).asUInt(),
    "b11".U -> VecInit(vfalu.map(_.io.fp_f64_result)).asUInt(),
  ))
  val s1_result = RegEnable(s0_result, validPipe(Latency-2))
  out.result := s1_result

  // Accept a new operation only when the whole pipeline is empty and the consumer is ready
  io.in.ready := !validPipe.reduce(_ || _) && io.out.ready
  io.out.valid := validPipe(Latency-1)
}
162
object fflagsGen{
  /** Builds running-OR exception flags for one group of lanes.
    *
    * @param vmask        mask bits; only the low `idx.length` bits are used
    * @param fflagsResult packed 5-bit fflags fields, field `f` occupying bits (5*f+4, 5*f)
    * @param idx          packed-field index of each element, in element order
    * @return Vec whose entry `k` is the OR of the masked fflags of elements 0..k
    */
  def fflagsGen(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    val count = idx.length
    // NOTE(review): element k is gated by vmask bit (count - 1 - k), i.e. the mask is applied in
    // reverse. RVV normally pairs mask bit i with element i — confirm this ordering is intended.
    val enables = vmask(count - 1, 0).asBools().reverse
    val maskedFlags = idx.zip(enables).map { case (field, enable) =>
      Mux(enable, fflagsResult(field * 5 + 4, field * 5), 0.U)
    }
    // Prefix OR: entry k accumulates the flags of elements 0..k
    VecInit(maskedFlags.tail.scanLeft(maskedFlags.head)(_ | _))
  }

  /** Same as [[fflagsGen]]; lets callers write `fflagsGen(...)` on the object. */
  def apply(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    fflagsGen(vmask, fflagsResult, idx)
  }
}
183
object VecExtractor{
  /** Replicates the low SEW bits of a scalar across a 128-bit value.
    *
    * @param sew element-width code; only bits (1, 0) are used: 0 -> 8, 1 -> 16, 2 -> 32, 3 -> 64 bits
    * @param xf  scalar source; all four splats are built unconditionally, so this expects at
    *            least 64 bits of `xf`
    * @return 128 bits in which every SEW-wide element equals the low SEW bits of `xf`
    */
  def xf2v_sew(sew: UInt, xf:UInt): UInt = {
    val elementWidths = Seq(8, 16, 32, 64)
    val splats = elementWidths.zipWithIndex.map { case (width, code) =>
      code.U(2.W) -> VecInit(Seq.fill(128 / width)(xf(width - 1, 0))).asUInt
    }
    LookupTree(sew(1, 0), splats)
  }

  /** Shorthand for [[xf2v_sew]]. */
  def apply(sew: UInt, xf: UInt): UInt = {
    xf2v_sew(sew, xf)
  }
}