xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/VIPU.scala (revision f9cac32fe90cbdfc694fb763fdd276b7809efc48)
1/****************************************************************************************
2 * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3 * Copyright (c) 2020-2021 Peng Cheng Laboratory
4 *
5 * XiangShan is licensed under Mulan PSL v2.
6 * You can use this software according to the terms and conditions of the Mulan PSL v2.
7 * You may obtain a copy of Mulan PSL v2 at:
8 *          http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 *
14 * See the Mulan PSL v2 for more details.
15 ****************************************************************************************
16 */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3._
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.alu.{VAluOpcode, VIAlu}
27import yunsuan.{VectorElementFormat, VipuType}
28import xiangshan.{SelImm, SrcType, UopDivType, XSCoreParamsKey, XSModule}
29
30import scala.collection.Seq
31
/** Vector integer processing unit.
  *
  * Accepts one uop at a time (`io.in.ready` is high only in `s_idle`), forwards it to a
  * [[VIAluWrapper]], and holds the result until the consumer accepts it:
  *  - s_idle    -> s_compute on an input handshake (the uop is latched into `s0_uopReg`);
  *  - s_compute -> s_idle when the ALU result is valid and `io.out` fires in the same cycle,
  *                 or -> s_finish when the ALU result is valid but `io.out.ready` is low;
  *  - s_finish  -> s_idle once `io.out` fires.
  *
  * The wrapper passes VIAlu's output valid straight through without consulting its own
  * `io.out.ready`, so the result data and the saturation flag are captured into registers
  * whenever the ALU output is valid and replayed from those registers in s_finish.
  */
class VIPU(implicit p: Parameters) extends VPUSubModule(p(XSCoreParamsKey).VLEN) {
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VipuType.dummy, "VIPU OpType not supported")

  // extra io
  val vxrm = IO(Input(UInt(2.W)))
  val vxsat = IO(Output(UInt(1.W)))

  // def some signal
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  // Saturation flag of the captured result; must be held together with dataReg so a
  // result parked in s_finish still reports the vxsat it was computed with.
  val vxsatReg = RegInit(0.U(1.W))
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vialu = Module(new VIAluWrapper)
  val outValid = vialu.io.out.valid
  val outFire = vialu.io.out.fire()

  // reg input signal
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val inHs = io.in.fire()
  when(inHs && state === s_idle) {
    s0_uopReg := io.in.bits.uop
  }
  // Capture every valid ALU result so it can be replayed if the consumer is not ready.
  when(outValid) {
    dataReg := dataWire
    vxsatReg := vialu.vxsat
  }

  // fsm
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      state := Mux(outValid, Mux(outFire, s_idle, s_finish),
                             s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }

  // connect VIAlu
  dataWire := vialu.io.out.bits.data
  vialu.io.in.bits <> io.in.bits
  vialu.io.redirectIn := DontCare  // TODO :
  vialu.vxrm := vxrm

  // Bypass the live ALU result whenever it is valid in s_compute (not only when it fires):
  // io.out.valid is already asserted in that cycle, and the same value is what dataReg
  // captures for s_finish, so the payload stays stable while valid is held.
  val bypass = state === s_compute && outValid
  io.out.bits.data := Mux(bypass, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg
  // Outside s_finish this is the ALU's flag exactly as before; in s_finish the ALU output
  // is no longer valid, so replay the flag captured with the pending result.
  vxsat := Mux(state === s_finish, vxsatReg, vialu.vxsat)

  vialu.io.in.valid := io.in.valid && state === s_idle
  io.out.valid := state === s_compute && outValid || state === s_finish
  vialu.io.out.ready := io.out.ready
  io.in.ready := state === s_idle
}
84
/** Control fields produced by [[VIAluDecoder]] for the yunsuan VIAlu.
  *
  * Total width is 6 + 2*4 + 4 = 18 bits. Field declaration order is significant: the
  * decoder builds a packed `Cat` and reinterprets it with `asTypeOf`, so `opcode` occupies
  * the most-significant bits and `vdType` the least-significant ones.
  */
class VIAluDecodeResultBundle extends Bundle {
  /** VIAlu operation code (a `VAluOpcode` value). */
  val opcode: UInt = UInt(6.W)
  /** Per-source element type encodings; element 0 sits in the lower bits when packed. */
  val srcType: Vec[UInt] = Vec(2, UInt(4.W))
  /** Destination element type encoding. */
  val vdType: UInt = UInt(4.W)
}
90
/** Combinational decoder from a VIPU `fuOpType` plus the current SEW to the VIAlu
  * opcode and operand/result element-type encodings.
  *
  * `fuOpType` values not listed in the table decode through `LookupTree` with no
  * matching entry.
  */
class VIAluDecoder (implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle{
    val in = Input(new Bundle{
      val fuOpType = UInt(8.W)
      val sew = UInt(2.W)
    })
    val out = Output(new VIAluDecodeResultBundle)
  })

  // 4-bit element-type encoding: bits [3:2] give the class (00 unsigned, 01 signed,
  // 10 float), bits [1:0] the element width in vsew encoding; 1111 marks a mask operand.
  // `sew +/- k` uses 2-bit arithmetic and wraps, so the widened (x2) and fractional
  // (f2/f4/f8) types are only meaningful when the shifted SEW stays in range.
  val uSew = Cat(0.U(2.W), io.in.sew)
  val uSew2 = Cat(0.U(2.W), (io.in.sew+1.U))
  val uSewf2 = Cat(0.U(2.W), (io.in.sew-1.U))
  val uSewf4 = Cat(0.U(2.W), (io.in.sew-2.U))
  val uSewf8 = Cat(0.U(2.W), (io.in.sew-3.U))
  val sSew = Cat(1.U(2.W), io.in.sew)
  val sSew2 = Cat(1.U(2.W), (io.in.sew+1.U))
  val sSewf2 = Cat(1.U(2.W), (io.in.sew - 1.U))
  val sSewf4 = Cat(1.U(2.W), (io.in.sew - 2.U))
  val sSewf8 = Cat(1.U(2.W), (io.in.sew - 3.U))
  val mask = "b1111".U(4.W)

  // Each entry is Cat(opcode, colA, colB, vdType), reinterpreted via asTypeOf. Chisel packs
  // Vec element 0 into the LOW bits, so colA lands in srcType(1) and colB in srcType(0).
  // NOTE(review): confirm which operand (vs1 / vs2) VIAlu associates with srcType(0); the
  // `.wv` entries below put the 2*SEW type in colA (i.e. srcType(1)).
  val out = LookupTree(io.in.fuOpType, List(
    // --------------------- opcode       colA = srcType(1)  colB = srcType(0)  vdType
    VipuType.vadd_vv -> Cat(VAluOpcode.vadd, uSew, uSew, uSew),
    VipuType.vsub_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew),
    VipuType.vrsub_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew),

    VipuType.vwaddu_vv -> Cat(VAluOpcode.vadd, uSew, uSew, uSew2),
    VipuType.vwsubu_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew2),
    VipuType.vwadd_vv -> Cat(VAluOpcode.vadd, sSew, sSew, sSew2),
    VipuType.vwsub_vv -> Cat(VAluOpcode.vsub, sSew, sSew, sSew2),
    VipuType.vwaddu_wv -> Cat(VAluOpcode.vadd, uSew2, uSew, uSew2),
    VipuType.vwsubu_wv -> Cat(VAluOpcode.vsub, uSew2, uSew, uSew2),
    VipuType.vwadd_wv -> Cat(VAluOpcode.vadd, sSew2, sSew, sSew2),
    VipuType.vwsub_wv -> Cat(VAluOpcode.vsub, sSew2, sSew, sSew2),

    VipuType.vzext_vf2 -> Cat(VAluOpcode.vext, uSewf2, uSewf2, uSew),
    VipuType.vsext_vf2 -> Cat(VAluOpcode.vext, sSewf2, sSewf2, sSew),
    VipuType.vzext_vf4 -> Cat(VAluOpcode.vext, uSewf4, uSewf4, uSew),
    VipuType.vsext_vf4 -> Cat(VAluOpcode.vext, sSewf4, sSewf4, sSew),
    VipuType.vzext_vf8 -> Cat(VAluOpcode.vext, uSewf8, uSewf8, uSew),
    VipuType.vsext_vf8 -> Cat(VAluOpcode.vext, sSewf8, sSewf8, sSew),

    VipuType.vadc_vvm -> Cat(VAluOpcode.vadc, uSew, uSew, uSew),
    VipuType.vmadc_vvm -> Cat(VAluOpcode.vmadc, uSew, uSew, mask),
    VipuType.vmadc_vv -> Cat(VAluOpcode.vmadc, uSew, uSew, mask),

    // The mask-producing borrow ops use the dedicated vmsbc opcode, mirroring vmadc above
    // (previously decoded as plain vsbc).
    VipuType.vsbc_vvm -> Cat(VAluOpcode.vsbc, uSew, uSew, uSew),
    VipuType.vmsbc_vvm -> Cat(VAluOpcode.vmsbc, uSew, uSew, mask),
    VipuType.vmsbc_vv -> Cat(VAluOpcode.vmsbc, uSew, uSew, mask),

    VipuType.vand_vv -> Cat(VAluOpcode.vand, uSew, uSew, uSew),
    VipuType.vor_vv -> Cat(VAluOpcode.vor, uSew, uSew, uSew),
    VipuType.vxor_vv -> Cat(VAluOpcode.vxor, uSew, uSew, uSew),

    VipuType.vsll_vv -> Cat(VAluOpcode.vsll, uSew, uSew, uSew),
    VipuType.vsrl_vv -> Cat(VAluOpcode.vsrl, uSew, uSew, uSew),
    VipuType.vsra_vv -> Cat(VAluOpcode.vsra, uSew, uSew, uSew),

    VipuType.vnsrl_wv -> Cat(VAluOpcode.vsrl, uSew2, uSew, uSew),
    VipuType.vnsra_wv -> Cat(VAluOpcode.vsra, uSew2, uSew, uSew),

    VipuType.vmseq_vv -> Cat(VAluOpcode.vmseq, uSew, uSew, mask),
    VipuType.vmsne_vv -> Cat(VAluOpcode.vmsne, uSew, uSew, mask),
    VipuType.vmsltu_vv -> Cat(VAluOpcode.vmslt, uSew, uSew, mask),
    VipuType.vmslt_vv -> Cat(VAluOpcode.vmslt, sSew, sSew, mask),
    VipuType.vmsleu_vv -> Cat(VAluOpcode.vmsle, uSew, uSew, mask),
    VipuType.vmsle_vv -> Cat(VAluOpcode.vmsle, sSew, sSew, mask),
    VipuType.vmsgtu_vv -> Cat(VAluOpcode.vmsgt, uSew, uSew, mask),
    VipuType.vmsgt_vv -> Cat(VAluOpcode.vmsgt, sSew, sSew, mask),

    VipuType.vminu_vv -> Cat(VAluOpcode.vmin, uSew, uSew, uSew),
    VipuType.vmin_vv -> Cat(VAluOpcode.vmin, sSew, sSew, sSew),
    VipuType.vmaxu_vv -> Cat(VAluOpcode.vmax, uSew, uSew, uSew),
    VipuType.vmax_vv -> Cat(VAluOpcode.vmax, sSew, sSew, sSew),

    // NOTE(review): vdType is `mask` here although vmerge writes SEW-wide elements; confirm
    // this matches what VIAlu expects for vmerge.
    VipuType.vmerge_vvm -> Cat(VAluOpcode.vmerge, uSew, uSew, mask),

    VipuType.vmv_v_v -> Cat(VAluOpcode.vmv, uSew, uSew, uSew),

    VipuType.vsaddu_vv -> Cat(VAluOpcode.vsadd, uSew, uSew, uSew),
    VipuType.vsadd_vv -> Cat(VAluOpcode.vsadd, sSew, sSew, sSew),
    VipuType.vssubu_vv -> Cat(VAluOpcode.vssub, uSew, uSew, uSew),
    VipuType.vssub_vv -> Cat(VAluOpcode.vssub, sSew, sSew, sSew),

    VipuType.vaaddu_vv -> Cat(VAluOpcode.vaadd, uSew, uSew, uSew),
    VipuType.vaadd_vv -> Cat(VAluOpcode.vaadd, sSew, sSew, sSew),
    VipuType.vasubu_vv -> Cat(VAluOpcode.vasub, uSew, uSew, uSew),
    VipuType.vasub_vv -> Cat(VAluOpcode.vasub, sSew, sSew, sSew),

    VipuType.vssrl_vv -> Cat(VAluOpcode.vssrl, uSew, uSew, uSew),
    VipuType.vssra_vv -> Cat(VAluOpcode.vssra, uSew, uSew, uSew),

    VipuType.vnclipu_wv -> Cat(VAluOpcode.vssrl, uSew2, uSew, uSew),
    VipuType.vnclip_wv -> Cat(VAluOpcode.vssra, uSew2, uSew, uSew),

    VipuType.vredsum_vs -> Cat(VAluOpcode.vredsum, uSew, uSew, uSew),
    VipuType.vredmaxu_vs -> Cat(VAluOpcode.vredmax, uSew, uSew, uSew),
    VipuType.vredmax_vs -> Cat(VAluOpcode.vredmax, sSew, sSew, sSew),
    VipuType.vredminu_vs -> Cat(VAluOpcode.vredmin, uSew, uSew, uSew),
    VipuType.vredmin_vs -> Cat(VAluOpcode.vredmin, sSew, sSew, sSew),
    VipuType.vredand_vs -> Cat(VAluOpcode.vredand, uSew, uSew, uSew),
    VipuType.vredor_vs -> Cat(VAluOpcode.vredor, uSew, uSew, uSew),
    VipuType.vredxor_vs -> Cat(VAluOpcode.vredxor, uSew, uSew, uSew),

    VipuType.vwredsumu_vs -> Cat(VAluOpcode.vredsum, uSew, uSew, uSew2),
    VipuType.vwredsum_vs -> Cat(VAluOpcode.vredsum, sSew, sSew, sSew2),

    VipuType.vmand_mm -> Cat(VAluOpcode.vand, mask, mask, mask),
    VipuType.vmnand_mm -> Cat(VAluOpcode.vnand, mask, mask, mask),
    VipuType.vmandn_mm -> Cat(VAluOpcode.vandn, mask, mask, mask),
    VipuType.vmxor_mm -> Cat(VAluOpcode.vxor, mask, mask, mask),
    VipuType.vmor_mm -> Cat(VAluOpcode.vor, mask, mask, mask),
    VipuType.vmnor_mm -> Cat(VAluOpcode.vnor, mask, mask, mask),
    VipuType.vmorn_mm -> Cat(VAluOpcode.vorn, mask, mask, mask),
    VipuType.vmxnor_mm -> Cat(VAluOpcode.vxnor, mask, mask, mask),

    VipuType.vcpop_m -> Cat(VAluOpcode.vcpop, mask, mask, mask),
    VipuType.vfirst_m -> Cat(VAluOpcode.vfirst, mask, mask, mask),
    VipuType.vmsbf_m -> Cat(VAluOpcode.vmsbf, mask, mask, mask),
    VipuType.vmsif_m -> Cat(VAluOpcode.vmsif, mask, mask, mask),
    VipuType.vmsof_m -> Cat(VAluOpcode.vmsof, mask, mask, mask),

    VipuType.viota_m -> Cat(VAluOpcode.viota, mask, mask, uSew),
    VipuType.vid_v -> Cat(VAluOpcode.vid, uSew, uSew, uSew)
  )).asTypeOf(new VIAluDecodeResultBundle)

  io.out := out
}
231
/** Adapter from the VPU sub-module interface to the yunsuan VIAlu.
  *
  * Selects the operands (immediate vs. register vs1, optional vs1/vs2 swap, optional mask
  * clearing), decodes the op type with [[VIAluDecoder]], and drives VIAlu's input fields.
  * No flow control is implemented here: `io.in.ready` is left as DontCare, `io.out.ready`
  * is never read, and `io.out.valid` is VIAlu's output valid passed straight through.
  */
class VIAluWrapper(implicit p: Parameters)  extends VPUSubModule(p(XSCoreParamsKey).VLEN) {
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VipuType.dummy, "VIPU OpType not supported")

  // extra io
  val vxrm = IO(Input(UInt(2.W)))
  val vxsat = IO(Output(UInt(1.W)))

  // short aliases into the incoming uop
  val in = io.in.bits
  val ctrl = in.uop.ctrl
  val vtype = ctrl.vconfig.vtype

  // ---- operand selection ----
  // The immediate is expanded per XLEN chunk and replicated across the whole vector register.
  val imm = VecInit(Seq.fill(VLEN / XLEN)(VecImmExtractor(ctrl.selImm, vtype.vsew, ctrl.imm))).asUInt
  val isMvLmul = ctrl.uopDivType === UopDivType.VEC_MV_LMUL
  val regVs1 = Mux(isMvLmul, VecExtractor(vtype.vsew, in.src(0)), in.src(0))
  val _vs1 = Mux(SrcType.isImm(ctrl.srcType(0)), imm, regVs1)
  val _vs2 = in.src(1)
  // Some op types want the two sources swapped before reaching the ALU.
  val reverse = VipuType.needReverse(ctrl.fuOpType)
  val vs1 = Mux(reverse, _vs2, _vs1)
  val vs2 = Mux(reverse, _vs1, _vs2)
  val mask = Mux(VipuType.needClearMask(ctrl.fuOpType), 0.U, in.src(3))

  // ---- decode ----
  val decoder = Module(new VIAluDecoder)
  decoder.io.in.fuOpType := ctrl.fuOpType
  decoder.io.in.sew := vtype.vsew(1, 0)

  // ---- VIAlu inputs ----
  val vialu = Module(new VIAlu)
  val aluIn = vialu.io.in.bits

  aluIn.opcode := decoder.io.out.opcode
  (0 until 2).foreach(i => aluIn.srcType(i) := decoder.io.out.srcType(i))
  aluIn.vdType := decoder.io.out.vdType

  aluIn.info.vm := ctrl.vm
  aluIn.info.ma := vtype.vma
  aluIn.info.ta := vtype.vta
  aluIn.info.vlmul := vtype.vlmul
  aluIn.info.vl := ctrl.vconfig.vl
  aluIn.info.vstart := 0.U // TODO :
  aluIn.info.uopIdx := ctrl.uopIdx
  aluIn.info.vxrm := vxrm

  aluIn.vs1 := vs1
  aluIn.vs2 := vs2
  aluIn.old_vd := in.src(2)
  aluIn.mask := mask

  vialu.io.in.valid := io.in.valid

  // ---- outputs ----
  io.out.valid := vialu.io.out.valid
  io.out.bits.data := vialu.io.out.bits.vd
  io.out.bits.uop := DontCare
  vxsat := vialu.io.out.bits.vxsat
  io.in.ready := DontCare
}
289
/** Builds a 64-bit value holding a vector-instruction immediate replicated to the
  * element width selected by `sew`.
  */
object VecImmExtractor {
  /** Sign-extend the 5-bit immediate field `imm(4,0)` to 8 bits. */
  def Imm_OPIVIS(imm: UInt): UInt = SignExt(imm(4, 0), 8)

  /** Zero-extend the 5-bit immediate field `imm(4,0)` to 8 bits. */
  def Imm_OPIVIU(imm: UInt): UInt = ZeroExt(imm(4, 0), 8)

  /** Sign-extend `imm(7,0)` to 64 bits, then tile its low 8/16/32/64 bits across 64 bits
    * for `sew(1,0)` = 00/01/10/11 respectively.
    */
  def imm_sew(sew: UInt, imm: UInt): UInt = {
    val wide = SignExt(imm(7, 0), 64)
    LookupTree(sew(1, 0), List(
      "b00".U -> Fill(8, wide(7, 0)),
      "b01".U -> Fill(4, wide(15, 0)),
      "b10".U -> Fill(2, wide(31, 0)),
      "b11".U -> wide(63, 0)
    ))
  }

  /** Extend the 5-bit immediate as signed when `immType` is `SelImm.IMM_OPIVIS` (unsigned
    * otherwise) and replicate it to the element width given by `sew`.
    */
  def apply(immType: UInt, sew: UInt, imm: UInt): UInt = {
    val isSigned = immType === SelImm.IMM_OPIVIS
    val imm8 = Mux(isSigned, Imm_OPIVIS(imm), Imm_OPIVIU(imm))
    imm_sew(sew, imm8(7, 0))
  }
}
313