// Source: XiangShan/src/main/scala/xiangshan/backend/fu/vector/VIPU.scala (revision 5d9d92aa25e485da3fc712f597489246ea88b692)
/****************************************************************************************
 * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
 * Copyright (c) 2020-2021 Peng Cheng Laboratory
 *
 * XiangShan is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 *
 * See the Mulan PSL v2 for more details.
 ****************************************************************************************
 */


19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3._
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.alu.{VAluOpcode, VIAlu}
27import yunsuan.{VectorElementFormat, VipuType}
28import xiangshan.{SelImm, SrcType, UopDivType, XSCoreParamsKey, XSModule}
29
30import scala.collection.Seq
31
/**
  * Vector integer processing unit.
  *
  * Wraps a VIAluWrapper in a one-request-at-a-time FSM:
  *  - s_idle:    io.in.ready is high; an accepted request's uop is latched.
  *  - s_compute: waits for the ALU result. If io.out fires in the same cycle the
  *               result becomes valid, it is forwarded straight through and the FSM
  *               returns to s_idle; otherwise the FSM moves to s_finish.
  *  - s_finish:  presents the latched result (data and vxsat) until io.out fires.
  */
class VIPU(implicit p: Parameters) extends VPUSubModule(p(XSCoreParamsKey).VLEN) {
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VipuType.dummy, "VIPU OpType not supported")

  // extra io
  val vxrm = IO(Input(UInt(2.W)))
  val vxsat = IO(Output(UInt(1.W)))

  // def some signal
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  // vxsat is latched together with the data, so s_finish reports the flag that
  // belongs to the held result rather than whatever the ALU currently drives.
  val vxsatReg = Reg(UInt(1.W))
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vialu = Module(new VIAluWrapper)
  val outValid = vialu.io.out.valid
  val outFire = vialu.io.out.fire()
  // The ALU result is forwarded combinationally in the cycle it becomes valid.
  // This is the same condition under which io.out.valid is raised in s_compute,
  // so the output data is never stale while valid is asserted.
  val forwardResult = state === s_compute && outValid

  // reg input signal
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val inHs = io.in.fire()
  when(inHs && state === s_idle){
    s0_uopReg := io.in.bits.uop
  }
  when(outValid) {
    dataReg := dataWire
    vxsatReg := vialu.vxsat
  }

  // fsm
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      state := Mux(outValid, Mux(outFire, s_idle, s_finish),
                             s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }

  // connect VIAlu
  dataWire := vialu.io.out.bits.data
  vialu.io.in.bits <> io.in.bits
  vialu.io.redirectIn := DontCare  // TODO :
  vialu.vxrm := vxrm
  io.out.bits.data := Mux(forwardResult, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg
  vxsat := Mux(forwardResult, vialu.vxsat, vxsatReg)

  vialu.io.in.valid := io.in.valid && state === s_idle
  io.out.valid := forwardResult || state === s_finish
  vialu.io.out.ready := io.out.ready
  io.in.ready := state === s_idle
}
84
/**
  * Decoded control fields for the yunsuan VIAlu.
  *
  * Field order is significant: VIAluDecoder builds this bundle with `asTypeOf`
  * from Cat(opcode, srcType2, srcType1, vdType), so the first field occupies the MSBs.
  */
class VIAluDecodeResultBundle extends Bundle {
  val opcode: UInt   = UInt(6.W) // VAluOpcode value
  val srcType2: UInt = UInt(4.W) // element type of vs2
  val srcType1: UInt = UInt(4.W) // element type of vs1
  val vdType: UInt   = UInt(4.W) // element type of vd
}
91
/**
  * Decodes a VIPU fuOpType, together with the current SEW, into the VIAlu opcode
  * and the element types of vs2, vs1 and vd.
  *
  * Element-type encoding (4 bits):
  *  - the upper 2 bits give the class: 00 unsigned, 01 signed, 10 float;
  *  - the lower 2 bits give the element width in vsew encoding;
  *  - "b1111" marks a mask operand.
  *
  * uSew2/sSew2 denote a doubled element width (widening and narrowing ops).
  * uSewfN/sSewfN denote 1/N of SEW (vzext/vsext sources). The arithmetic is
  * performed on the 2-bit sew, so it wraps modulo 4 and is only meaningful for
  * legal combinations.
  *
  * fuOpTypes missing from the table decode to all zeros (LookupTree semantics).
  */
class VIAluDecoder (implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle{
    val in = Input(new Bundle{
      val fuOpType = UInt(8.W)
      val sew = UInt(2.W)
    })
    val out = Output(new VIAluDecodeResultBundle)
  })

  // u 00 s 01 f 10 mask 1111
  val uSew = Cat(0.U(2.W), io.in.sew)
  val uSew2 = Cat(0.U(2.W), (io.in.sew+1.U))
  val uSewf2 = Cat(0.U(2.W), (io.in.sew-1.U))
  val uSewf4 = Cat(0.U(2.W), (io.in.sew-2.U))
  val uSewf8 = Cat(0.U(2.W), (io.in.sew-3.U))
  val sSew = Cat(1.U(2.W), io.in.sew)
  val sSew2 = Cat(1.U(2.W), (io.in.sew+1.U))
  val sSewf2 = Cat(1.U(2.W), (io.in.sew - 1.U))
  val sSewf4 = Cat(1.U(2.W), (io.in.sew - 2.U))
  val sSewf8 = Cat(1.U(2.W), (io.in.sew - 3.U))
  val mask = "b1111".U(4.W)

  val out = LookupTree(io.in.fuOpType, List(
    // fuOpType -> Cat(opcode, srcType2 (vs2), srcType1 (vs1), vdType)
    VipuType.vadd_vv -> Cat(VAluOpcode.vadd, uSew, uSew, uSew).asUInt(),
    VipuType.vsub_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew).asUInt(),
    VipuType.vrsub_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew).asUInt(),

    VipuType.vwaddu_vv -> Cat(VAluOpcode.vadd, uSew, uSew, uSew2).asUInt(),
    VipuType.vwsubu_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew2).asUInt(),
    VipuType.vwadd_vv -> Cat(VAluOpcode.vadd, sSew, sSew, sSew2).asUInt(),
    VipuType.vwsub_vv -> Cat(VAluOpcode.vsub, sSew, sSew, sSew2).asUInt(),
    VipuType.vwaddu_wv -> Cat(VAluOpcode.vadd, uSew2, uSew, uSew2).asUInt(),
    VipuType.vwsubu_wv -> Cat(VAluOpcode.vsub, uSew2, uSew, uSew2).asUInt(),
    VipuType.vwadd_wv -> Cat(VAluOpcode.vadd, sSew2, sSew, sSew2).asUInt(),
    VipuType.vwsub_wv -> Cat(VAluOpcode.vsub, sSew2, sSew, sSew2).asUInt(),

    VipuType.vzext_vf2 -> Cat(VAluOpcode.vext, uSewf2, uSewf2, uSew).asUInt(),
    VipuType.vsext_vf2 -> Cat(VAluOpcode.vext, sSewf2, sSewf2, sSew).asUInt(),
    VipuType.vzext_vf4 -> Cat(VAluOpcode.vext, uSewf4, uSewf4, uSew).asUInt(),
    VipuType.vsext_vf4 -> Cat(VAluOpcode.vext, sSewf4, sSewf4, sSew).asUInt(),
    VipuType.vzext_vf8 -> Cat(VAluOpcode.vext, uSewf8, uSewf8, uSew).asUInt(),
    VipuType.vsext_vf8 -> Cat(VAluOpcode.vext, sSewf8, sSewf8, sSew).asUInt(),

    VipuType.vadc_vvm -> Cat(VAluOpcode.vadc, uSew, uSew, uSew).asUInt(),
    VipuType.vmadc_vvm -> Cat(VAluOpcode.vmadc, uSew, uSew, mask).asUInt(),
    VipuType.vmadc_vv -> Cat(VAluOpcode.vmadc, uSew, uSew, mask).asUInt(),

    VipuType.vsbc_vvm -> Cat(VAluOpcode.vsbc, uSew, uSew, uSew).asUInt(),
    // vmsbc produces a borrow-out mask, like vmadc produces a carry-out mask,
    // so it needs the mask-producing opcode rather than vsbc.
    VipuType.vmsbc_vvm -> Cat(VAluOpcode.vmsbc, uSew, uSew, mask).asUInt(),
    VipuType.vmsbc_vv -> Cat(VAluOpcode.vmsbc, uSew, uSew, mask).asUInt(),

    VipuType.vand_vv -> Cat(VAluOpcode.vand, uSew, uSew, uSew).asUInt(),
    VipuType.vor_vv -> Cat(VAluOpcode.vor, uSew, uSew, uSew).asUInt(),
    VipuType.vxor_vv -> Cat(VAluOpcode.vxor, uSew, uSew, uSew).asUInt(),

    VipuType.vsll_vv -> Cat(VAluOpcode.vsll, uSew, uSew, uSew).asUInt(),
    VipuType.vsrl_vv -> Cat(VAluOpcode.vsrl, uSew, uSew, uSew).asUInt(),
    VipuType.vsra_vv -> Cat(VAluOpcode.vsra, uSew, uSew, uSew).asUInt(),

    VipuType.vnsrl_wv -> Cat(VAluOpcode.vsrl, uSew2, uSew, uSew).asUInt(),
    VipuType.vnsra_wv -> Cat(VAluOpcode.vsra, uSew2, uSew, uSew).asUInt(),

    VipuType.vmseq_vv -> Cat(VAluOpcode.vmseq, uSew, uSew, mask).asUInt(),
    VipuType.vmsne_vv -> Cat(VAluOpcode.vmsne, uSew, uSew, mask).asUInt(),
    VipuType.vmsltu_vv -> Cat(VAluOpcode.vmslt, uSew, uSew, mask).asUInt(),
    VipuType.vmslt_vv -> Cat(VAluOpcode.vmslt, sSew, sSew, mask).asUInt(),
    VipuType.vmsleu_vv -> Cat(VAluOpcode.vmsle, uSew, uSew, mask).asUInt(),
    VipuType.vmsle_vv -> Cat(VAluOpcode.vmsle, sSew, sSew, mask).asUInt(),
    VipuType.vmsgtu_vv -> Cat(VAluOpcode.vmsgt, uSew, uSew, mask).asUInt(),
    VipuType.vmsgt_vv -> Cat(VAluOpcode.vmsgt, sSew, sSew, mask).asUInt(),

    VipuType.vminu_vv -> Cat(VAluOpcode.vmin, uSew, uSew, uSew).asUInt(),
    VipuType.vmin_vv -> Cat(VAluOpcode.vmin, sSew, sSew, sSew).asUInt(),
    VipuType.vmaxu_vv -> Cat(VAluOpcode.vmax, uSew, uSew, uSew).asUInt(),
    VipuType.vmax_vv -> Cat(VAluOpcode.vmax, sSew, sSew, sSew).asUInt(),

    // vmerge only *reads* v0 as a select mask; its destination holds SEW-wide
    // elements, not a mask.
    VipuType.vmerge_vvm -> Cat(VAluOpcode.vmerge, uSew, uSew, uSew).asUInt(),

    VipuType.vmv_v_v -> Cat(VAluOpcode.vmv, uSew, uSew, uSew).asUInt(),

    VipuType.vsaddu_vv -> Cat(VAluOpcode.vsadd, uSew, uSew, uSew).asUInt(),
    VipuType.vsadd_vv -> Cat(VAluOpcode.vsadd, sSew, sSew, sSew).asUInt(),
    VipuType.vssubu_vv -> Cat(VAluOpcode.vssub, uSew, uSew, uSew).asUInt(),
    VipuType.vssub_vv -> Cat(VAluOpcode.vssub, sSew, sSew, sSew).asUInt(),

    VipuType.vaaddu_vv -> Cat(VAluOpcode.vaadd, uSew, uSew, uSew).asUInt(),
    VipuType.vaadd_vv -> Cat(VAluOpcode.vaadd, sSew, sSew, sSew).asUInt(),
    VipuType.vasubu_vv -> Cat(VAluOpcode.vasub, uSew, uSew, uSew).asUInt(),
    VipuType.vasub_vv -> Cat(VAluOpcode.vasub, sSew, sSew, sSew).asUInt(),

    VipuType.vssrl_vv -> Cat(VAluOpcode.vssrl, uSew, uSew, uSew).asUInt(),
    VipuType.vssra_vv -> Cat(VAluOpcode.vssra, uSew, uSew, uSew).asUInt(),

    VipuType.vnclipu_wv -> Cat(VAluOpcode.vssrl, uSew2, uSew, uSew).asUInt(),
    VipuType.vnclip_wv -> Cat(VAluOpcode.vssra, uSew2, uSew, uSew).asUInt(),

    VipuType.vredsum_vs -> Cat(VAluOpcode.vredsum, uSew, uSew, uSew).asUInt(),
    VipuType.vredmaxu_vs -> Cat(VAluOpcode.vredmax, uSew, uSew, uSew).asUInt(),
    VipuType.vredmax_vs -> Cat(VAluOpcode.vredmax, sSew, sSew, sSew).asUInt(),
    VipuType.vredminu_vs -> Cat(VAluOpcode.vredmin, uSew, uSew, uSew).asUInt(),
    VipuType.vredmin_vs -> Cat(VAluOpcode.vredmin, sSew, sSew, sSew).asUInt(),
    VipuType.vredand_vs -> Cat(VAluOpcode.vredand, uSew, uSew, uSew).asUInt(),
    VipuType.vredor_vs -> Cat(VAluOpcode.vredor, uSew, uSew, uSew).asUInt(),
    VipuType.vredxor_vs -> Cat(VAluOpcode.vredxor, uSew, uSew, uSew).asUInt(),

    VipuType.vwredsumu_vs -> Cat(VAluOpcode.vredsum, uSew, uSew, uSew2).asUInt(),
    VipuType.vwredsum_vs -> Cat(VAluOpcode.vredsum, sSew, sSew, sSew2).asUInt(),

    VipuType.vmand_mm -> Cat(VAluOpcode.vand, mask, mask, mask).asUInt(),
    VipuType.vmnand_mm -> Cat(VAluOpcode.vnand, mask, mask, mask).asUInt(),
    VipuType.vmandn_mm -> Cat(VAluOpcode.vandn, mask, mask, mask).asUInt(),
    VipuType.vmxor_mm -> Cat(VAluOpcode.vxor, mask, mask, mask).asUInt(),
    VipuType.vmor_mm -> Cat(VAluOpcode.vor, mask, mask, mask).asUInt(),
    VipuType.vmnor_mm -> Cat(VAluOpcode.vnor, mask, mask, mask).asUInt(),
    VipuType.vmorn_mm -> Cat(VAluOpcode.vorn, mask, mask, mask).asUInt(),
    VipuType.vmxnor_mm -> Cat(VAluOpcode.vxnor, mask, mask, mask).asUInt(),

    VipuType.vcpop_m -> Cat(VAluOpcode.vcpop, mask, mask, mask).asUInt(),
    VipuType.vfirst_m -> Cat(VAluOpcode.vfirst, mask, mask, mask).asUInt(),
    VipuType.vmsbf_m -> Cat(VAluOpcode.vmsbf, mask, mask, mask).asUInt(),
    VipuType.vmsif_m -> Cat(VAluOpcode.vmsif, mask, mask, mask).asUInt(),
    VipuType.vmsof_m -> Cat(VAluOpcode.vmsof, mask, mask, mask).asUInt(),

    VipuType.viota_m -> Cat(VAluOpcode.viota, mask, mask, uSew).asUInt(),
    VipuType.vid_v -> Cat(VAluOpcode.vid, uSew, uSew, uSew).asUInt()
  )).asTypeOf(new VIAluDecodeResultBundle)

  io.out <> out
}
232
/**
  * Adapter between the VPU sub-module interface and the yunsuan VIAlu.
  *
  * It prepares the operands, decodes fuOpType through VIAluDecoder, and forwards
  * the VIAlu result. Operand preparation covers:
  *  - immediate expansion;
  *  - the VecExtractor path for move-type uops;
  *  - the operand swap requested by VipuType.needReverse;
  *  - the mask clearing requested by VipuType.needClearMask.
  *
  * io.out.bits.uop and io.in.ready are left as DontCare here; VIPU drives its own
  * uop and ready handshake.
  */
class VIAluWrapper(implicit p: Parameters)  extends VPUSubModule(p(XSCoreParamsKey).VLEN) {
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VipuType.dummy, "VIPU OpType not supported")

  // extra io
  val vxrm = IO(Input(UInt(2.W)))
  val vxsat = IO(Output(UInt(1.W)))

  // shorthands for the incoming request
  val req = io.in.bits
  val reqCtrl = req.uop.ctrl
  val reqVtype = reqCtrl.vconfig.vtype
  val opType = reqCtrl.fuOpType

  // ---- operand preparation ----
  // Per-element immediate, replicated VLEN/XLEN times to span the whole register.
  val immPerXlen = Seq.fill(VLEN / XLEN)(VecImmExtractor(reqCtrl.selImm, reqVtype.vsew, reqCtrl.imm))
  val immVec = VecInit(immPerXlen).asUInt

  // Move-type uops route src(0) through VecExtractor (defined elsewhere;
  // presumably selects/broadcasts an element by SEW - TODO confirm).
  val mvDivTypes = Seq(UopDivType.VEC_MV_LMUL, UopDivType.VEC_MV_WIDE, UopDivType.VEC_MV_WIDE0, UopDivType.VEC_MV_NARROW)
  val isMvUop = mvDivTypes.map(reqCtrl.uopDivType === _).reduce(_ || _)
  val regSrc1 = Mux(isMvUop, VecExtractor(reqVtype.vsew, req.src(0)), req.src(0))
  val src1 = Mux(SrcType.isImm(reqCtrl.srcType(0)), immVec, regSrc1)
  val src2 = req.src(1)

  // Some ops want vs1 and vs2 exchanged before entering the ALU.
  val swapSrcs = VipuType.needReverse(opType)
  val aluVs1 = Mux(swapSrcs, src2, src1)
  val aluVs2 = Mux(swapSrcs, src1, src2)
  val aluMask = Mux(VipuType.needClearMask(opType), 0.U, req.src(3))

  // ---- decoder + ALU ----
  val decoder = Module(new VIAluDecoder)
  val vialu = Module(new VIAlu)
  decoder.io.in.fuOpType := opType
  decoder.io.in.sew := reqVtype.vsew(1,0)

  val aluReq = vialu.io.in.bits
  aluReq.opcode := decoder.io.out.opcode

  aluReq.info.vm := reqCtrl.vm
  aluReq.info.ma := reqVtype.vma
  aluReq.info.ta := reqVtype.vta
  aluReq.info.vlmul := reqVtype.vlmul
  aluReq.info.vl := reqCtrl.vconfig.vl
  aluReq.info.vstart := 0.U // TODO :
  aluReq.info.uopIdx := reqCtrl.uopIdx
  aluReq.info.vxrm := vxrm

  aluReq.srcType(0) := decoder.io.out.srcType1
  aluReq.srcType(1) := decoder.io.out.srcType2
  aluReq.vdType := decoder.io.out.vdType
  aluReq.vs1 := aluVs1
  aluReq.vs2 := aluVs2
  aluReq.old_vd := req.src(2)
  aluReq.mask := aluMask

  vialu.io.in.valid := io.in.valid

  // ---- outputs ----
  io.out.valid := vialu.io.out.valid
  io.out.bits.data := vialu.io.out.bits.vd
  io.out.bits.uop := DontCare
  vxsat := vialu.io.out.bits.vxsat
  io.in.ready := DontCare
}
290
/**
  * Builds the 64-bit immediate operand for OPIVI vector instructions.
  *
  * The 5-bit immediate is first extended to 8 bits, signed or unsigned depending
  * on the immediate type. It is then sign-extended to one SEW-wide element and
  * that element is replicated to fill 64 bits.
  */
object VecImmExtractor {
  /** Sign-extends the low 5 bits of `imm` to 8 bits. */
  def Imm_OPIVIS(imm: UInt): UInt = SignExt(imm(4, 0), 8)

  /** Zero-extends the low 5 bits of `imm` to 8 bits. */
  def Imm_OPIVIU(imm: UInt): UInt = ZeroExt(imm(4, 0), 8)

  /**
    * Sign-extends the 8-bit `imm` to 64 bits and tiles one element of width
    * 8 << sew(1,0) across the 64-bit result.
    */
  def imm_sew(sew: UInt, imm: UInt): UInt = {
    val ext64 = SignExt(imm(7, 0), 64)
    LookupTree(sew(1, 0), List(
      "b00".U -> Fill(8, ext64(7, 0)),
      "b01".U -> Fill(4, ext64(15, 0)),
      "b10".U -> Fill(2, ext64(31, 0)),
      "b11".U -> ext64(63, 0)
    ))
  }

  /**
    * @param immType selects signed (SelImm.IMM_OPIVIS) or unsigned (anything else) extension
    * @param sew     vsew field; only bits (1,0) are used
    * @param imm     raw immediate; only bits (4,0) are used
    * @return 64-bit replicated immediate
    */
  def apply(immType: UInt, sew: UInt, imm: UInt): UInt = {
    val signedImm = immType === SelImm.IMM_OPIVIS
    val imm8 = Mux(signedImm, Imm_OPIVIS(imm), Imm_OPIVIU(imm))
    imm_sew(sew, imm8(7, 0))
  }
}
314