xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/VIPU.scala (revision 1e160ed8f7c18a21fbe4b2b074150af3df0aeb09)
1/****************************************************************************************
2 * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3 * Copyright (c) 2020-2021 Peng Cheng Laboratory
4 *
5 * XiangShan is licensed under Mulan PSL v2.
6 * You can use this software according to the terms and conditions of the Mulan PSL v2.
7 * You may obtain a copy of Mulan PSL v2 at:
8 *          http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 *
14 * See the Mulan PSL v2 for more details.
15 ****************************************************************************************
16 */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3._
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.alu.{VAluOpcode, VIAlu}
27import yunsuan.{VectorElementFormat, VipuType}
28import xiangshan.{SelImm, SrcType, UopDivType, XSCoreParamsKey, XSModule}
29
30import scala.collection.Seq
31
/**
  * Vector integer processing unit: a handshake shell around [[VIAluWrapper]].
  *
  * A three-state FSM sequences one operation at a time:
  *  - `s_idle`    : `io.in.ready` is high; the uop is latched when `io.in` fires.
  *  - `s_compute` : waiting for the wrapper to raise `out.valid`. If `io.out.ready`
  *                  is high in that cycle the result is forwarded directly.
  *  - `s_finish`  : the result arrived while `io.out.ready` was low; the latched
  *                  copy is presented until `io.out` fires.
  *
  * VIAluWrapper (as defined in this file) drives its `out.valid` from VIAlu and never
  * forwards `out.ready` to it, so nothing visible here guarantees that the result
  * stays on the wires after the valid cycle. Both the result data and the
  * saturation flag are therefore latched when `out.valid` is seen.
  *
  * Extra ports:
  *  - `vxrm`  : input, passed through to the wrapper.
  *  - `vxsat` : output, saturation flag belonging to the result on `io.out`.
  */
class VIPU(implicit p: Parameters) extends VPUSubModule(p(XSCoreParamsKey).VLEN) {
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VipuType.dummy, "VIPU OpType not supported")

  // extra io
  val vxrm = IO(Input(UInt(2.W)))
  val vxsat = IO(Output(UInt(1.W)))

  // internal signals
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vialuWp = Module(new VIAluWrapper)
  val outValid = vialuWp.io.out.valid
  val outFire = vialuWp.io.out.fire()
  // Latched copy of the saturation flag, kept in step with dataReg so that a result
  // delivered from s_finish carries the vxsat that was produced with it.
  val vxsatReg = Reg(UInt(1.W))
  // High in the cycle the wrapper's result is on the wires. Deliberately independent
  // of io.out.ready, so the payload shown while io.out.valid is high never changes
  // with the consumer's ready signal.
  val resultLive = state === s_compute && outValid

  // latch the incoming uop
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val inHs = io.in.fire()
  when(inHs && state === s_idle){
    s0_uopReg := io.in.bits.uop
  }

  // capture the result (data and vxsat together) whenever the wrapper flags it valid
  when(outValid) {
    dataReg := dataWire
    vxsatReg := vialuWp.vxsat
  }

  // fsm
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      // result taken immediately -> idle; result produced but not taken -> finish
      state := Mux(outValid, Mux(outFire, s_idle, s_finish), s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }

  // connect VIAlu wrapper
  dataWire := vialuWp.io.out.bits.data
  vialuWp.io.in.bits <> io.in.bits
  vialuWp.io.redirectIn := DontCare  // TODO :
  vialuWp.vxrm := vxrm
  vialuWp.vstart := vstart

  // outputs: live result in the cycle it is produced, latched copy afterwards
  io.out.bits.data := Mux(resultLive, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg
  vxsat := Mux(resultLive, vialuWp.vxsat, vxsatReg)

  // handshake
  vialuWp.io.in.valid := io.in.valid && state === s_idle
  io.out.valid := resultLive || state === s_finish
  vialuWp.io.out.ready := io.out.ready
  io.in.ready := state === s_idle
}
85
/**
  * Control word produced by VIAluDecoder for one vector integer operation.
  *
  * The field order is significant: the decoder builds the word with
  * `Cat(opcode, srcType2, srcType1, vdType)` and reinterprets it via `asTypeOf`,
  * so the first field declared here occupies the most significant bits.
  *
  * Each 4-bit type field is formed in the decoder as a 2-bit class in the upper bits
  * (00 unsigned, 01 signed, 10 float per the decoder's note) over a 2-bit element
  * width code; the all-ones value 1111 denotes a mask operand.
  */
class VIAluDecodeResultBundle extends Bundle {
  val opcode   = UInt(6.W) // VAluOpcode selector
  val srcType2 = UInt(4.W) // element type of the vs2 operand
  val srcType1 = UInt(4.W) // element type of the vs1 operand
  val vdType   = UInt(4.W) // element type of the destination
}
92
/**
  * Translates a VIPU `fuOpType` plus the current element width code into the
  * opcode and operand/destination type tags expected by VIAlu.
  *
  * Inputs:
  *  - `io.in.fuOpType` : one of the `VipuType` encodings.
  *  - `io.in.sew`      : 2-bit element width code taken from vtype.vsew(1,0).
  *
  * Output: a [[VIAluDecodeResultBundle]]. The word is assembled as
  * `Cat(opcode, srcType2, srcType1, vdType)` and reinterpreted with `asTypeOf`.
  *
  * Type tag layout (4 bits): upper two bits select the class
  * (00 unsigned, 01 signed, 10 float), lower two bits carry the width code;
  * 1111 marks a mask operand.
  *
  * The table has no default entry, so an unlisted `fuOpType` selects nothing.
  */
class VIAluDecoder (implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle{
    val in = Input(new Bundle{
      val fuOpType = UInt(8.W)
      val sew = UInt(2.W)
    })
    val out = Output(new VIAluDecodeResultBundle)
  })

  // u 00 s 01 f 10 mask 1111
  // The width arithmetic below is done on 2 bits and therefore wraps; the suffixes
  // mean: 2 = one step wider, f2/f4/f8 = one/two/three steps narrower.
  val uSew = Cat(0.U(2.W), io.in.sew)
  val uSew2 = Cat(0.U(2.W), (io.in.sew+1.U))
  val uSewf2 = Cat(0.U(2.W), (io.in.sew-1.U))
  val uSewf4 = Cat(0.U(2.W), (io.in.sew-2.U))
  val uSewf8 = Cat(0.U(2.W), (io.in.sew-3.U))
  val sSew = Cat(1.U(2.W), io.in.sew)
  val sSew2 = Cat(1.U(2.W), (io.in.sew+1.U))
  val sSewf2 = Cat(1.U(2.W), (io.in.sew - 1.U))
  val sSewf4 = Cat(1.U(2.W), (io.in.sew - 2.U))
  val sSewf8 = Cat(1.U(2.W), (io.in.sew - 3.U))
  val mask = "b1111".U(4.W)

  val out = LookupTree(io.in.fuOpType, List(
    // --------------------- opcode       srcType2 1 vdType
    // single-width add / subtract (vrsub shares the vsub opcode)
    VipuType.vadd_vv -> Cat(VAluOpcode.vadd, uSew, uSew, uSew),
    VipuType.vsub_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew),
    VipuType.vrsub_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew),

    // widening add / subtract
    VipuType.vwaddu_vv -> Cat(VAluOpcode.vadd, uSew, uSew, uSew2),
    VipuType.vwsubu_vv -> Cat(VAluOpcode.vsub, uSew, uSew, uSew2),
    VipuType.vwadd_vv -> Cat(VAluOpcode.vadd, sSew, sSew, sSew2),
    VipuType.vwsub_vv -> Cat(VAluOpcode.vsub, sSew, sSew, sSew2),
    VipuType.vwaddu_wv -> Cat(VAluOpcode.vadd, uSew2, uSew, uSew2),
    VipuType.vwsubu_wv -> Cat(VAluOpcode.vsub, uSew2, uSew, uSew2),
    VipuType.vwadd_wv -> Cat(VAluOpcode.vadd, sSew2, sSew, sSew2),
    VipuType.vwsub_wv -> Cat(VAluOpcode.vsub, sSew2, sSew, sSew2),

    // zero / sign extension from a narrower source
    VipuType.vzext_vf2 -> Cat(VAluOpcode.vext, uSewf2, uSewf2, uSew),
    VipuType.vsext_vf2 -> Cat(VAluOpcode.vext, sSewf2, sSewf2, sSew),
    VipuType.vzext_vf4 -> Cat(VAluOpcode.vext, uSewf4, uSewf4, uSew),
    VipuType.vsext_vf4 -> Cat(VAluOpcode.vext, sSewf4, sSewf4, sSew),
    VipuType.vzext_vf8 -> Cat(VAluOpcode.vext, uSewf8, uSewf8, uSew),
    VipuType.vsext_vf8 -> Cat(VAluOpcode.vext, sSewf8, sSewf8, sSew),

    // add with carry (vmadc writes a mask)
    VipuType.vadc_vvm -> Cat(VAluOpcode.vadc, uSew, uSew, uSew),
    VipuType.vmadc_vvm -> Cat(VAluOpcode.vmadc, uSew, uSew, mask),
    VipuType.vmadc_vv -> Cat(VAluOpcode.vmadc, uSew, uSew, mask),

    // subtract with borrow (vmsbc writes a mask)
    VipuType.vsbc_vvm -> Cat(VAluOpcode.vsbc, uSew, uSew, uSew),
    VipuType.vmsbc_vvm -> Cat(VAluOpcode.vmsbc, uSew, uSew, mask),
    VipuType.vmsbc_vv -> Cat(VAluOpcode.vmsbc, uSew, uSew, mask),

    // bitwise logic
    VipuType.vand_vv -> Cat(VAluOpcode.vand, uSew, uSew, uSew),
    VipuType.vor_vv -> Cat(VAluOpcode.vor, uSew, uSew, uSew),
    VipuType.vxor_vv -> Cat(VAluOpcode.vxor, uSew, uSew, uSew),

    // shifts
    VipuType.vsll_vv -> Cat(VAluOpcode.vsll, uSew, uSew, uSew),
    VipuType.vsrl_vv -> Cat(VAluOpcode.vsrl, uSew, uSew, uSew),
    VipuType.vsra_vv -> Cat(VAluOpcode.vsra, uSew, uSew, uSew),

    // narrowing shifts
    VipuType.vnsrl_wv -> Cat(VAluOpcode.vsrl, uSew2, uSew, uSew),
    VipuType.vnsra_wv -> Cat(VAluOpcode.vsra, uSew2, uSew, uSew),

    // compares (mask destination)
    VipuType.vmseq_vv -> Cat(VAluOpcode.vmseq, uSew, uSew, mask),
    VipuType.vmsne_vv -> Cat(VAluOpcode.vmsne, uSew, uSew, mask),
    VipuType.vmsltu_vv -> Cat(VAluOpcode.vmslt, uSew, uSew, mask),
    VipuType.vmslt_vv -> Cat(VAluOpcode.vmslt, sSew, sSew, mask),
    VipuType.vmsleu_vv -> Cat(VAluOpcode.vmsle, uSew, uSew, mask),
    VipuType.vmsle_vv -> Cat(VAluOpcode.vmsle, sSew, sSew, mask),
    VipuType.vmsgtu_vv -> Cat(VAluOpcode.vmsgt, uSew, uSew, mask),
    VipuType.vmsgt_vv -> Cat(VAluOpcode.vmsgt, sSew, sSew, mask),

    // min / max
    VipuType.vminu_vv -> Cat(VAluOpcode.vmin, uSew, uSew, uSew),
    VipuType.vmin_vv -> Cat(VAluOpcode.vmin, sSew, sSew, sSew),
    VipuType.vmaxu_vv -> Cat(VAluOpcode.vmax, uSew, uSew, uSew),
    VipuType.vmax_vv -> Cat(VAluOpcode.vmax, sSew, sSew, sSew),

    // vmerge selects per element between vs1 and vs2 and writes SEW-wide elements,
    // so its destination is tagged as a vector (like vmv/vadc), not as a mask.
    VipuType.vmerge_vvm -> Cat(VAluOpcode.vmerge, uSew, uSew, uSew),

    VipuType.vmv_v_v -> Cat(VAluOpcode.vmv, uSew, uSew, uSew),

    // saturating add / subtract
    VipuType.vsaddu_vv -> Cat(VAluOpcode.vsadd, uSew, uSew, uSew),
    VipuType.vsadd_vv -> Cat(VAluOpcode.vsadd, sSew, sSew, sSew),
    VipuType.vssubu_vv -> Cat(VAluOpcode.vssub, uSew, uSew, uSew),
    VipuType.vssub_vv -> Cat(VAluOpcode.vssub, sSew, sSew, sSew),

    // averaging add / subtract
    VipuType.vaaddu_vv -> Cat(VAluOpcode.vaadd, uSew, uSew, uSew),
    VipuType.vaadd_vv -> Cat(VAluOpcode.vaadd, sSew, sSew, sSew),
    VipuType.vasubu_vv -> Cat(VAluOpcode.vasub, uSew, uSew, uSew),
    VipuType.vasub_vv -> Cat(VAluOpcode.vasub, sSew, sSew, sSew),

    // scaling shifts
    VipuType.vssrl_vv -> Cat(VAluOpcode.vssrl, uSew, uSew, uSew),
    VipuType.vssra_vv -> Cat(VAluOpcode.vssra, uSew, uSew, uSew),

    // narrowing clips
    VipuType.vnclipu_wv -> Cat(VAluOpcode.vssrl, uSew2, uSew, uSew),
    VipuType.vnclip_wv -> Cat(VAluOpcode.vssra, uSew2, uSew, uSew),

    // reductions
    VipuType.vredsum_vs -> Cat(VAluOpcode.vredsum, uSew, uSew, uSew),
    VipuType.vredmaxu_vs -> Cat(VAluOpcode.vredmax, uSew, uSew, uSew),
    VipuType.vredmax_vs -> Cat(VAluOpcode.vredmax, sSew, sSew, sSew),
    VipuType.vredminu_vs -> Cat(VAluOpcode.vredmin, uSew, uSew, uSew),
    VipuType.vredmin_vs -> Cat(VAluOpcode.vredmin, sSew, sSew, sSew),
    VipuType.vredand_vs -> Cat(VAluOpcode.vredand, uSew, uSew, uSew),
    VipuType.vredor_vs -> Cat(VAluOpcode.vredor, uSew, uSew, uSew),
    VipuType.vredxor_vs -> Cat(VAluOpcode.vredxor, uSew, uSew, uSew),

    // widening reductions
    VipuType.vwredsumu_vs -> Cat(VAluOpcode.vredsum, uSew, uSew, uSew2),
    VipuType.vwredsum_vs -> Cat(VAluOpcode.vredsum, sSew, sSew, sSew2),

    // mask-register logic
    VipuType.vmand_mm -> Cat(VAluOpcode.vand, mask, mask, mask),
    VipuType.vmnand_mm -> Cat(VAluOpcode.vnand, mask, mask, mask),
    VipuType.vmandn_mm -> Cat(VAluOpcode.vandn, mask, mask, mask),
    VipuType.vmxor_mm -> Cat(VAluOpcode.vxor, mask, mask, mask),
    VipuType.vmor_mm -> Cat(VAluOpcode.vor, mask, mask, mask),
    VipuType.vmnor_mm -> Cat(VAluOpcode.vnor, mask, mask, mask),
    VipuType.vmorn_mm -> Cat(VAluOpcode.vorn, mask, mask, mask),
    VipuType.vmxnor_mm -> Cat(VAluOpcode.vxnor, mask, mask, mask),

    // mask population / scan operations
    VipuType.vcpop_m -> Cat(VAluOpcode.vcpop, mask, mask, mask),
    VipuType.vfirst_m -> Cat(VAluOpcode.vfirst, mask, mask, mask),
    VipuType.vmsbf_m -> Cat(VAluOpcode.vmsbf, mask, mask, mask),
    VipuType.vmsif_m -> Cat(VAluOpcode.vmsif, mask, mask, mask),
    VipuType.vmsof_m -> Cat(VAluOpcode.vmsof, mask, mask, mask),

    // index generation
    VipuType.viota_m -> Cat(VAluOpcode.viota, mask, mask, uSew),
    VipuType.vid_v -> Cat(VAluOpcode.vid, uSew, uSew, uSew),

  )).asTypeOf(new VIAluDecodeResultBundle)

  io.out := out
}
224
/**
  * Adapts the VPU sub-module interface to the VIAlu primitive.
  *
  * Responsibilities visible in this block:
  *  - builds the two vector operands (immediate replication, extraction for the
  *    vector-move uop kinds, optional operand swap via `VipuType.needReverse`);
  *  - zeroes the mask operand when `VipuType.needClearMask` says so;
  *  - decodes `fuOpType` / `vsew` through [[VIAluDecoder]] and feeds VIAlu;
  *  - exposes VIAlu's `vd` as `io.out.bits.data` and its `vxsat` on the extra port.
  *
  * `io.out.bits.uop` and `io.in.ready` are left as DontCare, and `io.out.ready`
  * is not read here.
  */
class VIAluWrapper(implicit p: Parameters)  extends VPUSubModule(p(XSCoreParamsKey).VLEN) {
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VipuType.dummy, "VIPU OpType not supported")

  // extra io
  val vxrm = IO(Input(UInt(2.W)))
  val vxsat = IO(Output(UInt(1.W)))

  // shorthand for frequently used input fields
  val in = io.in.bits
  val ctrl = in.uop.ctrl
  val vtype = ctrl.vconfig.vtype

  // ---- operand preparation ----
  // immediate, replicated once per XLEN-wide slice of the vector register
  val imm = Fill(VLEN/XLEN, VecImmExtractor(ctrl.selImm, vtype.vsew, ctrl.imm))

  // uop kinds whose first source goes through VecExtractor
  val isVecMove = List(
    UopDivType.VEC_MV_LMUL,
    UopDivType.VEC_MV_WIDE,
    UopDivType.VEC_MV_WIDE0,
    UopDivType.VEC_MV_NARROW,
    UopDivType.VEC_MV_MASK
  ).map(ctrl.uopDivType === _).reduce(_ || _)

  val _vs1 = Mux(
    SrcType.isImm(ctrl.srcType(0)),
    imm,
    Mux(isVecMove, VecExtractor(vtype.vsew, in.src(0)), in.src(0))
  )
  val _vs2 = in.src(1)

  // some operations take their operands in the opposite order
  val swapSrc = VipuType.needReverse(ctrl.fuOpType)
  val vs1 = Mux(swapSrc, _vs2, _vs1)
  val vs2 = Mux(swapSrc, _vs1, _vs2)
  val mask = Mux(VipuType.needClearMask(ctrl.fuOpType), 0.U, in.src(3))

  // ---- decoder and ALU instances ----
  val decoder = Module(new VIAluDecoder)
  val vialu = Module(new VIAlu)

  decoder.io.in.fuOpType := ctrl.fuOpType
  decoder.io.in.sew := vtype.vsew(1,0)

  val aluIn = vialu.io.in.bits
  val aluInfo = aluIn.info

  // operation selection and operand typing
  aluIn.opcode := decoder.io.out.opcode
  aluIn.srcType(0) := decoder.io.out.srcType2
  aluIn.srcType(1) := decoder.io.out.srcType1
  aluIn.vdType := decoder.io.out.vdType

  // vector configuration
  aluInfo.vm := ctrl.vm
  aluInfo.ma := vtype.vma
  aluInfo.ta := vtype.vta
  aluInfo.vlmul := vtype.vlmul
  aluInfo.vl := ctrl.vconfig.vl
  aluInfo.vstart := vstart // TODO :
  aluInfo.uopIdx := ctrl.uopIdx
  aluInfo.vxrm := vxrm

  // operands
  aluIn.vs1 := vs1
  aluIn.vs2 := vs2
  aluIn.old_vd := in.src(2)
  aluIn.mask := mask

  vialu.io.in.valid := io.in.valid

  val vdOut = vialu.io.out.bits.vd
  val vxsatOut = vialu.io.out.bits.vxsat

  // ---- module outputs ----
  io.out.valid := vialu.io.out.valid
  io.out.bits.data := vdOut
  io.out.bits.uop := DontCare
  vxsat := vxsatOut
  io.in.ready := DontCare
}
282
/**
  * Expands the 5-bit immediate of a vector instruction into a 64-bit value in which
  * the immediate is replicated once per element of the selected width.
  */
object VecImmExtractor {
  /** Signed form: bits 4..0 of `imm`, sign-extended to 8 bits. */
  def Imm_OPIVIS(imm: UInt): UInt = SignExt(imm(4,0), 8)

  /** Unsigned form: bits 4..0 of `imm`, zero-extended to 8 bits. */
  def Imm_OPIVIU(imm: UInt): UInt = ZeroExt(imm(4,0), 8)

  /**
    * Replicates an 8-bit immediate across 64 bits according to the element width.
    *
    * The low byte of `imm` is first sign-extended to 64 bits; the low 8/16/32/64 bits
    * of that value are then repeated 8/4/2/1 times for `sew(1,0)` = 0/1/2/3.
    *
    * @param sew element width code; only bits 1..0 are used
    * @param imm immediate; only bits 7..0 are used
    * @return 64-bit replicated immediate
    */
  def imm_sew(sew: UInt, imm: UInt): UInt = {
    val wide = SignExt(imm(7,0), 64)
    // list position is the width code: 0 -> 8 bit, 1 -> 16 bit, 2 -> 32 bit, 3 -> 64 bit
    val perWidth = List(8, 16, 32, 64).zipWithIndex.map { case (elemBits, code) =>
      code.U(2.W) -> Fill(64 / elemBits, wide(elemBits - 1, 0))
    }
    LookupTree(sew(1,0), perWidth)
  }

  /**
    * @param immType immediate selector; `SelImm.IMM_OPIVIS` picks the signed form,
    *                any other value the unsigned form
    * @param sew     element width code forwarded to [[imm_sew]]
    * @param imm     raw immediate field
    * @return 64-bit replicated immediate
    */
  def apply(immType: UInt, sew: UInt, imm: UInt): UInt = {
    val signedSel = immType === SelImm.IMM_OPIVIS
    val imm8 = Mux(signedSel, Imm_OPIVIS(imm), Imm_OPIVIU(imm))
    imm_sew(sew, imm8(7,0))
  }
}
306