// Source: XiangShan/src/main/scala/device/MemEncryptUtil.scala (revision 30f35717e23156cb95b30a36db530384545b48a4)
/***************************************************************************************
* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
15
16package device
17
18import chisel3._
19import chisel3.util._
20import chisel3.util.HasBlackBoxResource
21import org.chipsalliance.cde.config.Parameters
22import freechips.rocketchip.amba.axi4._
23import freechips.rocketchip.diplomacy._
24import freechips.rocketchip.util._
25import freechips.rocketchip.amba.apb._
26import freechips.rocketchip.tilelink.AXI4TLState
27
/** A `chisel3.util.Queue` whose enqueue and dequeue ports are typed as [[IrrevocableIO]].
  *
  * @param gen     element type
  * @param entries queue depth, passed straight to `Queue`
  * @param flow    passed straight to `Queue` (an element may bypass an empty queue in the same cycle)
  */
class IrrevocableQueue[T <: Data](gen: T, entries: Int, flow: Boolean = false) extends Module {
  val io = IO(new Bundle {
    val enq = Flipped(Irrevocable(gen))
    val deq = Irrevocable(gen)
  })
  val queue = Module(new Queue(gen, entries = entries, flow = flow))

  // Drive valid/bits from `src` into `dst` and return `dst`'s ready to `src`.
  private def forward(src: ReadyValidIO[T], dst: ReadyValidIO[T]): Unit = {
    dst.valid := src.valid
    dst.bits  := src.bits
    src.ready := dst.ready
  }

  forward(io.enq, queue.io.enq)
  forward(queue.io.deq, io.deq)
}
43
/** Fixed-priority arbiter (lowest index wins) with [[IrrevocableIO]] inputs and output.
  *
  * A plain `chisel3.util.Arbiter` re-evaluates its grant every cycle. While the output is stalled
  * (`valid && !ready`), a lower-index input that becomes valid would therefore replace
  * `io.out.bits`. That breaks the Irrevocable contract on `io.out`, and AXI4's rule that a stalled
  * payload stays stable. Here the grant is held on the stalled input until its beat is accepted.
  *
  * @param gen payload type
  * @param n   number of inputs
  */
class IrrevocableArbiter[T <: Data](gen: T, n: Int) extends Module {
  val io = IO(new Bundle {
    val in = Flipped(Vec(n, Irrevocable(gen)))
    val out = Irrevocable(gen)
  })

  val valids = VecInit(io.in.map(_.valid)).asUInt

  // One-hot source of the beat currently stalled on io.out; zero when nothing is held.
  val heldOH = RegInit(0.U(n.W))

  // Keep the held grant; otherwise pick the lowest-index valid input (same priority as Arbiter).
  val grantOH = Mux(heldOH.orR, heldOH, PriorityEncoderOH(valids))

  io.out.valid := (grantOH & valids).orR
  io.out.bits  := Mux1H(grantOH, io.in.map(_.bits))
  // Only the granted input may fire, so a masked-out valid input never loses its beat.
  io.in.zipWithIndex.foreach { case (in, i) => in.ready := io.out.ready && grantOH(i) }

  // Lock the grant for the next cycle only if the current beat was offered but not taken.
  heldOH := Mux(io.out.valid && !io.out.ready, grantOH, 0.U)
}
65
// SM4 key-expansion fixed parameters CK_i (i = 0..31).
// Byte j of CK_i (j = 0 is the most significant byte) is (4 * i + j) * 7 mod 256;
// e.g. CK_0 = 0x00070e15, CK_9 = 0xfc030a11.
// Round i of the key schedule mixes CK_i into the key state (see OneRoundForKeyExp.ckParameterIn).
class GetCKI extends Module {
  val io = IO(new Bundle {
    val countRoundIn = Input(UInt(5.W))
    val ckiOut = Output(UInt(32.W))
  })
  // CK_0 .. CK_31, indexed by round number.
  val ckiValuesVec = VecInit(Seq(
    "h00070e15".U, "h1c232a31".U, "h383f464d".U, "h545b6269".U,
    "h70777e85".U, "h8c939aa1".U, "ha8afb6bd".U, "hc4cbd2d9".U,
    "he0e7eef5".U, "hfc030a11".U, "h181f262d".U, "h343b4249".U,
    "h50575e65".U, "h6c737a81".U, "h888f969d".U, "ha4abb2b9".U,
    "hc0c7ced5".U, "hdce3eaf1".U, "hf8ff060d".U, "h141b2229".U,
    "h30373e45".U, "h4c535a61".U, "h686f767d".U, "h848b9299".U,
    "ha0a7aeb5".U, "hbcc3cad1".U, "hd8dfe6ed".U, "hf4fb0209".U,
    "h10171e25".U, "h2c333a41".U, "h484f565d".U, "h646b7279".U
  ))
  // countRoundIn is 5 bits wide, so it always addresses one of the 32 entries; the former
  // `< 32` guard and its zeroing branch were unreachable.
  // The looked-up constant is registered: ckiOut lags countRoundIn by one cycle (reset value 0).
  val ckiOutReg = RegInit(0.U(32.W))
  ckiOutReg := ckiValuesVec(io.countRoundIn)
  io.ckiOut := ckiOutReg
}
94
95
// SM4 byte substitution (the non-linear part of the round transforms).
// resultOut = Sbox[dataIn], where Sbox is the fixed 256-entry SM4 S-box; purely combinational.
class SboxReplace extends Module {
  val io = IO(new Bundle {
    val dataIn = Input(UInt(8.W))
    val resultOut = Output(UInt(8.W))
  })

  // SM4 S-box, row-major: entry (16 * hi + lo) is the substitute for input byte 0x{hi}{lo}.
  private val table: Seq[Int] = Seq(
    0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05,
    0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99,
    0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62,
    0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6,
    0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8,
    0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35,
    0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87,
    0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e,
    0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1,
    0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3,
    0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f,
    0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51,
    0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8,
    0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0,
    0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 0x84,
    0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48
  )

  // Every entry is an 8-bit value, so the lookup vector is a Vec of UInt(8.W).
  val sbox = VecInit(table.map(_.U(8.W)))
  io.resultOut := sbox(io.dataIn)
}
127
// SM4 round transform T(A) = L(tau(A)) used for data encryption and decryption.
// tau substitutes each of the four bytes of A through the S-box; then
// L(B) = B ^ (B <<< 2) ^ (B <<< 10) ^ (B <<< 18) ^ (B <<< 24).
class TransformForEncDec extends Module {
  val io = IO(new Bundle {
    val data_in = Input(UInt(32.W))
    val result_out = Output(UInt(32.W))
  })

  // 32-bit rotate left by a constant amount n (0 < n < 32).
  private def rotl(x: UInt, n: Int): UInt = Cat(x(31 - n, 0), x(31, 32 - n))

  // tau: one S-box instance per byte; byte k is data_in(8k+7, 8k).
  val substituted = (0 until 4).map { k =>
    val sbox = Module(new SboxReplace)
    sbox.io.dataIn := io.data_in(8 * k + 7, 8 * k)
    sbox.io.resultOut
  }
  // Cat puts its first argument in the MSBs, so reverse to keep byte k at bits 8k+7..8k.
  val word_replaced = Cat(substituted.reverse)

  // L: XOR each rotated copy into the substituted word.
  io.result_out := Seq(2, 10, 18, 24).foldLeft(word_replaced)((acc, n) => acc ^ rotl(word_replaced, n))
}
151
152
// SM4 key-expansion transform T'(A) = L'(tau(A)).
// tau substitutes each of the four bytes of A through the S-box; then
// L'(B) = B ^ (B <<< 13) ^ (B <<< 23).
class TransformForKeyExp extends Module {
  val io = IO(new Bundle {
    val data_in = Input(UInt(32.W))
    val data_after_linear_key_out = Output(UInt(32.W))
  })

  // 32-bit rotate left by a constant amount n (0 < n < 32).
  private def rotl(x: UInt, n: Int): UInt = Cat(x(31 - n, 0), x(31, 32 - n))

  // tau: one S-box instance per byte; byte k is data_in(8k+7, 8k).
  val substituted = (0 until 4).map { k =>
    val sbox = Module(new SboxReplace)
    sbox.io.dataIn := io.data_in(8 * k + 7, 8 * k)
    sbox.io.resultOut
  }
  // Cat puts its first argument in the MSBs, so reverse to keep byte k at bits 8k+7..8k.
  val word_replaced = Cat(substituted.reverse)

  io.data_after_linear_key_out := Seq(13, 23).foldLeft(word_replaced)((acc, n) => acc ^ rotl(word_replaced, n))
}
// One of the 32 rounds of the SM4 key schedule (combinational).
// dataIn holds the state (K_i, K_{i+1}, K_{i+2}, K_{i+3}) with K_i in bits 127:96.
// In round 0 (countRoundIn === 0) dataIn is the user key, which is first XORed word-wise with the
// system parameters FK0..FK3. The round computes
//   K_{i+4} = K_i ^ T'(K_{i+1} ^ K_{i+2} ^ K_{i+3} ^ CK_i)   (CK_i = ckParameterIn)
// and outputs the shifted state (K_{i+1}, K_{i+2}, K_{i+3}, K_{i+4}).
class OneRoundForKeyExp extends Module {
  val io = IO(new Bundle {
    val countRoundIn = Input(UInt(5.W))
    val dataIn = Input(UInt(128.W))
    val ckParameterIn = Input(UInt(32.W))
    val resultOut = Output(UInt(128.W))
  })

  // SM4 system parameters FK0..FK3.
  val FK0 = "ha3b1bac6".U
  val FK1 = "h56aa3350".U
  val FK2 = "h677d9197".U
  val FK3 = "hb27022dc".U

  val isFirstRound = io.countRoundIn === 0.U

  // State words, most significant first; whitened with FK only in round 0.
  val k = Seq(FK0, FK1, FK2, FK3).zipWithIndex.map { case (fk, idx) =>
    val w = io.dataIn(127 - 32 * idx, 96 - 32 * idx)
    Mux(isFirstRound, w ^ fk, w)
  }

  val transformKey = Module(new TransformForKeyExp)
  transformKey.io.data_in := k(1) ^ k(2) ^ k(3) ^ io.ckParameterIn

  io.resultOut := Cat(k(1), k(2), k(3), transformKey.io.data_after_linear_key_out ^ k(0))
}
209
// One of the 32 SM4 cipher rounds (combinational); the same round serves encryption and
// decryption, which differ only in the order the round keys are supplied.
// data_in = (X_i, X_{i+1}, X_{i+2}, X_{i+3}) with X_i in bits 127:96.
//   X_{i+4} = X_i ^ T(X_{i+1} ^ X_{i+2} ^ X_{i+3} ^ rk_i)   (rk_i = round_key_in)
// result_out = (X_{i+1}, X_{i+2}, X_{i+3}, X_{i+4}).
class OneRoundForEncDec  extends Module {
  val io = IO(new Bundle {
    val data_in = Input(UInt(128.W))
    val round_key_in = Input(UInt(32.W))
    val result_out = Output(UInt(128.W))
  })

  val x0 = io.data_in(127, 96)
  val x1 = io.data_in(95, 64)
  val x2 = io.data_in(63, 32)
  val x3 = io.data_in(31, 0)

  val transform_encdec = Module(new TransformForEncDec)
  transform_encdec.io.data_in := x1 ^ x2 ^ x3 ^ io.round_key_in

  io.result_out := Cat(x1, x2, x3, x0 ^ transform_encdec.io.result_out)
}
229
230
231
// AXI4 W-channel side-band fields carried alongside the data while it is encrypted:
// byte strobes, the last-beat flag, and the request user fields tagged as per-beat data.
// Field order is the bundle layout; keep it unchanged.
class AXI4BundleWWithoutData(params: AXI4BundleParameters) extends Bundle {
  val strb = UInt((params.dataBits / 8).W)  // one strobe bit per data byte
  val last = Bool()
  val user = BundleMap(params.requestFields.filter(field => field.key.isData))
}
237
// AXI4 R-channel fields other than the data, carried alongside the data while it is decrypted.
// Field order is the bundle layout; keep it unchanged.
class AXI4BundleRWithoutData(params: AXI4BundleParameters) extends Bundle {
  val id   = UInt(params.idBits.W)
  val resp = UInt(params.respBits.W)
  val user = BundleMap(params.responseFields)
  val echo = BundleMap(params.echoFields)
  val last = Bool()
}
245
// Common I/O of one stage of the write-data encryption pipeline.
// A stage receives a 128-bit block plus its key ID, tweak, AXI4 W side-band fields and the
// 32 / MemencPipes round keys for this stage. It produces the transformed block and forwards
// the side-band values. Implementations: OnePipeForEnc, OnePipeForEncNoEnc.
abstract class OnePipeEncBase(implicit p: Parameters) extends MemEncryptModule {
  val io = IO(new Bundle {
    val onepipe_in = new Bundle {
      val keyid = Input(UInt(KeyIDBits.W))
      val data_in = Input(UInt(128.W))
      val tweak_in = Input(UInt(128.W))
      val axi4_other = Input(new AXI4BundleWWithoutData(MemcedgeIn.bundle))
      val round_key_in = Input(Vec(32/MemencPipes, UInt(32.W)))
    }
    val onepipe_out = new Bundle {
      val result_out = Output(UInt(128.W))
      val axi4_other_out = Output(new AXI4BundleWWithoutData(MemcedgeIn.bundle))
      val tweak_out = Output(UInt(128.W))
      // Same width as onepipe_in.keyid. This was hard-coded to 5 bits, which truncated or
      // zero-extended the key ID whenever KeyIDBits != 5.
      val keyid_out = Output(UInt(KeyIDBits.W))
    }
  })
}
265
// One stage of the write-data encryption pipeline. It applies 32 / MemencPipes chained SM4 rounds
// combinationally, each with its own key from round_key_in, and passes key ID, tweak and the
// AXI4 W side-band fields through unchanged. The stage count MemencPipes is configurable.
class OnePipeForEnc(implicit p: Parameters) extends OnePipeEncBase {

  val OneRoundForEncDecs = Seq.fill(32/MemencPipes)(Module(new OneRoundForEncDec))

  // Round k uses round_key_in(k).
  OneRoundForEncDecs.zip(io.onepipe_in.round_key_in).foreach { case (round, key) =>
    round.io.round_key_in := key
  }

  // Chain: data_in -> round 0 -> round 1 -> ... -> last round.
  val stageResult = OneRoundForEncDecs.foldLeft(io.onepipe_in.data_in) { (state, round) =>
    round.io.data_in := state
    round.io.result_out
  }

  io.onepipe_out.result_out := stageResult
  io.onepipe_out.tweak_out := io.onepipe_in.tweak_in
  io.onepipe_out.axi4_other_out := io.onepipe_in.axi4_other
  io.onepipe_out.keyid_out := io.onepipe_in.keyid
}
// Test stand-in for OnePipeForEnc. It performs no encryption: data_in is forwarded unchanged,
// as are key ID, tweak and the AXI4 W side-band fields. round_key_in is ignored.
class OnePipeForEncNoEnc(implicit p: Parameters) extends OnePipeEncBase {
  val src = io.onepipe_in
  val dst = io.onepipe_out
  dst.axi4_other_out := src.axi4_other
  dst.tweak_out := src.tweak_in
  dst.keyid_out := src.keyid
  dst.result_out := src.data_in
}
294
// Common I/O of one stage of the read-data decryption pipeline.
// A stage receives a 128-bit block plus its key ID, tweak, AXI4 R side-band fields and the
// 32 / MemencPipes round keys for this stage. It produces the transformed block and forwards
// the side-band values. Implementations: OnePipeForDec, OnePipeForDecNoDec.
abstract class OnePipeDecBase(implicit p: Parameters) extends MemEncryptModule {
  val io = IO(new Bundle {
    val onepipe_in = new Bundle {
      val keyid = Input(UInt(KeyIDBits.W))
      val data_in = Input(UInt(128.W))
      val tweak_in = Input(UInt(128.W))
      val axi4_other = Input(new AXI4BundleRWithoutData(MemcedgeOut.bundle))
      val round_key_in = Input(Vec(32/MemencPipes, UInt(32.W)))
    }
    val onepipe_out = new Bundle {
      val result_out = Output(UInt(128.W))
      val axi4_other_out = Output(new AXI4BundleRWithoutData(MemcedgeOut.bundle))
      val tweak_out = Output(UInt(128.W))
      // Same width as onepipe_in.keyid. This was hard-coded to 5 bits, which truncated or
      // zero-extended the key ID whenever KeyIDBits != 5.
      val keyid_out = Output(UInt(KeyIDBits.W))
    }
  })
}
// One stage of the read-data decryption pipeline. It applies 32 / MemencPipes chained SM4 rounds
// combinationally, each with its own key from round_key_in. SM4 decryption reuses the encryption
// round; presumably the caller supplies the round keys in reverse order — confirm.
// Key ID, tweak and the AXI4 R side-band fields pass through unchanged.
class OnePipeForDec(implicit p: Parameters) extends OnePipeDecBase {

  val OneRoundForEncDecs = Seq.fill(32/MemencPipes)(Module(new OneRoundForEncDec))

  // Round k uses round_key_in(k).
  OneRoundForEncDecs.zip(io.onepipe_in.round_key_in).foreach { case (round, key) =>
    round.io.round_key_in := key
  }

  // Chain: data_in -> round 0 -> round 1 -> ... -> last round.
  val stageResult = OneRoundForEncDecs.foldLeft(io.onepipe_in.data_in) { (state, round) =>
    round.io.data_in := state
    round.io.result_out
  }

  io.onepipe_out.result_out := stageResult
  io.onepipe_out.tweak_out := io.onepipe_in.tweak_in
  io.onepipe_out.axi4_other_out := io.onepipe_in.axi4_other
  io.onepipe_out.keyid_out := io.onepipe_in.keyid
}
// Test stand-in for OnePipeForDec. It performs no decryption: data_in is forwarded unchanged,
// as are key ID, tweak and the AXI4 R side-band fields. round_key_in is ignored.
class OnePipeForDecNoDec(implicit p: Parameters) extends OnePipeDecBase {
  val src = io.onepipe_in
  val dst = io.onepipe_out
  dst.tweak_out := src.tweak_in
  dst.axi4_other_out := src.axi4_other
  dst.keyid_out := src.keyid
  dst.result_out := src.data_in
}
334
// XTS tweak expansion in GF(2^128).
// From the encrypted tweak T0 = tweak_in it derives T1 = T0*alpha, T2 = T1*alpha and T3 = T2*alpha,
// one per 128-bit block, so that identical plaintext blocks encrypt to different ciphertexts.
// Multiplying by alpha shifts left by one bit within 128 bits. If the bit shifted out (bit 127)
// was 1, 0x87 is XORed into the low byte.
// tweak_out = {T3, T2, T1, T0}, with T0 in bits 127:0.
// NOTE(review): bit 127 is treated as the most significant coefficient; IEEE 1619 (XTS) defines
// this multiplication on a little-endian byte string — confirm the tweak's bit layout matches.
class GF128 extends Module{
  val io = IO(new Bundle {
    val tweak_in = Input(UInt(128.W))
    val tweak_out = Output(UInt(512.W))
  })

  val gf_128_fdbk = "h87".U(8.W)

  // One multiplication by alpha (explicitly truncated to 128 bits).
  private def mulAlpha(t: UInt): UInt = {
    val shifted = Cat(t(126, 0), 0.U(1.W))
    Mux(t(127), shifted ^ gf_128_fdbk, shifted)
  }

  // tweaks(k) = tweak_in * alpha^k for k = 0..3
  private val tweaks = Seq.iterate(io.tweak_in, 4)(mulAlpha)

  // Cat puts its first argument in the MSBs, hence the reverse.
  io.tweak_out := Cat(tweaks.reverse)
}
369
// Expands the encrypted tweak of a request into per-beat tweaks.
// GF128 yields T0..T3 (see GF128); each response carries two of them (256 bits):
//   len == 0: one response, {T1, T0} if addr(5) == 0, otherwise {T3, T2};
//   len != 0: two responses, {T1, T0} then {T3, T2}.
// NOTE(review): len values above 1 still produce only two responses — confirm bursts never exceed 2 beats.
// keyid_out is the top KeyIDBits bits of the request address.
class TweakGF128(implicit p: Parameters) extends MemEncryptModule{
  val io = IO(new Bundle {
    val req = Flipped(DecoupledIO(new Bundle {
      val len = UInt(MemcedgeIn.bundle.lenBits.W)
      val addr = UInt(PAddrBits.W)
      val tweak_in = UInt(128.W)
    }))
    val resp = DecoupledIO(new Bundle {
      val tweak_out = UInt(256.W)
      val keyid_out = UInt(KeyIDBits.W)
      val addr_out = UInt(PAddrBits.W)
    })
  })
  val tweak_gf128 = Module(new GF128())
  tweak_gf128.io.tweak_in := io.req.bits.tweak_in

  val reg_valid = RegInit(false.B)
  val reg_counter = RegInit(0.U(2.W))
  val reg_len = RegInit(0.U(MemcedgeIn.bundle.lenBits.W))
  val reg_addr = RegInit(0.U(PAddrBits.W))
  val reg_tweak_result = RegInit(0.U(512.W))

  // The response being offered is the final one for its request.
  val onLastBeat = reg_len === 0.U || reg_counter =/= 0.U
  // Select {T3, T2} rather than {T1, T0} for the response being offered.
  val useUpperPair = Mux(reg_len === 0.U, reg_addr(5), reg_counter =/= 0.U)

  // A new request may be taken while empty, or in the same cycle the final response is accepted.
  io.req.ready := !reg_valid || (io.resp.ready && onLastBeat)

  when(io.req.fire) {
    reg_tweak_result := tweak_gf128.io.tweak_out
    reg_len := io.req.bits.len
    reg_addr := io.req.bits.addr
    reg_valid := true.B
    reg_counter := 0.U
  }.elsewhen(io.resp.fire) {
    when(onLastBeat) {
      reg_valid := false.B
      reg_counter := 0.U
    }.otherwise {
      reg_counter := reg_counter + 1.U
    }
  }

  io.resp.valid          := reg_valid
  io.resp.bits.addr_out  := reg_addr
  io.resp.bits.keyid_out := reg_addr(PAddrBits - 1, PAddrBits - KeyIDBits)
  io.resp.bits.tweak_out := Mux(useUpperPair, reg_tweak_result(511, 256), reg_tweak_result(255, 0))
}
426
// One stage of the tweak-encryption pipeline (see TweakEncrypt). It applies 32 / MemencPipes
// chained SM4 rounds to data_in combinationally, each with its own key from round_key_in, and
// passes addr, len and id through unchanged.
class OnePipeForTweakEnc(implicit p: Parameters) extends MemEncryptModule {
  val io = IO(new Bundle {
    val in = new Bundle {
      val data_in = Input(UInt(128.W))
      val addr_in = Input(UInt(PAddrBits.W))
      val len_in = Input(UInt(MemcedgeOut.bundle.lenBits.W))
      val id_in = Input(UInt(MemcedgeOut.bundle.idBits.W))
      val round_key_in = Input(Vec(32/MemencPipes, UInt(32.W)))
    }
    val out = new Bundle {
      val result_out = Output(UInt(128.W))
      val addr_out = Output(UInt(PAddrBits.W))
      val len_out = Output(UInt(MemcedgeOut.bundle.lenBits.W))
      val id_out = Output(UInt(MemcedgeOut.bundle.idBits.W))
    }
  })

  val OneRoundForEncDecs = Seq.fill(32/MemencPipes)(Module(new OneRoundForEncDec))

  // Round k uses round_key_in(k).
  OneRoundForEncDecs.zip(io.in.round_key_in).foreach { case (round, key) =>
    round.io.round_key_in := key
  }

  // Chain: data_in -> round 0 -> round 1 -> ... -> last round.
  val stageResult = OneRoundForEncDecs.foldLeft(io.in.data_in) { (state, round) =>
    round.io.data_in := state
    round.io.result_out
  }

  io.out.result_out := stageResult
  io.out.id_out     := io.in.id_in
  io.out.len_out    := io.in.len_in
  io.out.addr_out   := io.in.addr_in
}
457
// Initial-tweak encryption pipeline.
// The tweak is encrypted with the 32 tweak round keys over MemencPipes registered stages
// (32 / MemencPipes SM4 rounds per stage, see OnePipeForTweakEnc). The SM4 final word reversal is
// then applied to the last stage's state. addr, len and id travel with the tweak.
// `opt` selects the AXI4 edge that sizes len/id: MemcedgeIn if true, MemcedgeOut otherwise.
// Each stage uses a valid/ready handshake; a stage accepts new data when it is empty or its
// content is leaving this cycle.
// NOTE(review): every stage reads its round keys from io.tweak_enc_req.bits.tweak_round_keys in the
// current cycle, not from the request it is processing. This is only correct if those keys stay
// constant while a tweak is in flight — confirm.
// NOTE(review): OnePipeForTweakEnc sizes len/id with MemcedgeOut. With opt = true (MemcedgeIn),
// len/id are resized on the way through each stage if the two edges differ — confirm.
class TweakEncrypt(opt: Boolean)(implicit p: Parameters) extends MemEncryptModule{
  val edgeUse = if (opt) MemcedgeIn else MemcedgeOut
  val io = IO(new Bundle {
    val tweak_enc_req = Flipped(DecoupledIO(new Bundle {
      val tweak = UInt(128.W)
      val addr_in = UInt(PAddrBits.W)
      val len_in = UInt(edgeUse.bundle.lenBits.W)
      val id_in = UInt(edgeUse.bundle.idBits.W)
      val tweak_round_keys = Vec(32, UInt(32.W))
    }))
    val tweak_enc_resp = DecoupledIO(new Bundle {
      val tweak_encrpty = UInt(128.W)
      val addr_out = UInt(PAddrBits.W)
      val len_out = UInt(edgeUse.bundle.lenBits.W)
      val id_out = UInt(edgeUse.bundle.idBits.W)
    })
  })

  // Per-stage output registers.
  val reg_tweak = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_addr = Reg(Vec(MemencPipes, UInt(PAddrBits.W)))
  val reg_len = Reg(Vec(MemencPipes, UInt(edgeUse.bundle.lenBits.W)))
  val reg_id = Reg(Vec(MemencPipes, UInt(edgeUse.bundle.idBits.W)))
  val reg_tweak_valid = RegInit(VecInit(Seq.fill(MemencPipes)(false.B)))

  // Round keys grouped per stage. Stage i's key j (global index i * keysPerPipe + j) sits in bits [32j+31 : 32j].
  val keysPerPipe = 32 / MemencPipes
  val wire_round_key = Wire(Vec(MemencPipes, UInt((32 * 32 / MemencPipes).W)))
  for (i <- 0 until MemencPipes) {
    val stageKeys = io.tweak_enc_req.bits.tweak_round_keys.slice(i * keysPerPipe, (i + 1) * keysPerPipe)
    wire_round_key(i) := VecInit(stageKeys).asUInt
  }

  // wire_ready_result(i): stage i can accept data this cycle.
  val wire_ready_result = WireInit(VecInit(Seq.fill(MemencPipes)(false.B)))

  // Valid of the data offered to stage i, and ready of whatever consumes stage i's output.
  private def upstreamValid(i: Int): Bool =
    if (i == 0) io.tweak_enc_req.valid else reg_tweak_valid(i - 1)
  private def downstreamReady(i: Int): Bool =
    if (i == MemencPipes - 1) io.tweak_enc_resp.ready else wire_ready_result(i + 1)

  val tweak_enc_modules = (0 until MemencPipes).map { i =>
    val advance = wire_ready_result(i) && upstreamValid(i)   // stage i loads new data
    val drain = reg_tweak_valid(i) && downstreamReady(i)     // stage i's data moves on

    reg_tweak_valid(i) := advance || (reg_tweak_valid(i) && !drain)
    wire_ready_result(i) := !reg_tweak_valid(i) || downstreamReady(i)

    val stage = Module(new OnePipeForTweakEnc())
    if (i == 0) {
      stage.io.in.data_in := io.tweak_enc_req.bits.tweak
      stage.io.in.addr_in := io.tweak_enc_req.bits.addr_in
      stage.io.in.len_in  := io.tweak_enc_req.bits.len_in
      stage.io.in.id_in   := io.tweak_enc_req.bits.id_in
    } else {
      stage.io.in.data_in := reg_tweak(i - 1)
      stage.io.in.addr_in := reg_addr(i - 1)
      stage.io.in.len_in  := reg_len(i - 1)
      stage.io.in.id_in   := reg_id(i - 1)
    }
    for (j <- 0 until 32/MemencPipes) {
      stage.io.in.round_key_in(j) := wire_round_key(i)(j * 32 + 31, j * 32)
    }

    when(advance) {
      reg_tweak(i) := stage.io.out.result_out
      reg_addr(i)  := stage.io.out.addr_out
      reg_len(i)   := stage.io.out.len_out
      reg_id(i)    := stage.io.out.id_out
    }
    stage
  }

  // SM4 final reverse transform: emit the last state's 32-bit words in reverse order.
  val result_out = Cat((0 until 4).map(w => reg_tweak.last(32 * w + 31, 32 * w)))

  io.tweak_enc_resp.valid              := reg_tweak_valid.last
  io.tweak_enc_resp.bits.tweak_encrpty := result_out
  io.tweak_enc_resp.bits.addr_out      := reg_addr.last
  io.tweak_enc_resp.bits.len_out       := reg_len.last
  io.tweak_enc_resp.bits.id_out        := reg_id.last
  io.tweak_enc_req.ready               := wire_ready_result(0)
}
543
544
// One entry of the AR-channel tweak table (TweakTable), indexed by AXI4 ID.
class TweakTableEntry(implicit val p: Parameters) extends Bundle with Memconsts {
  val v_flag        = Bool()                              // entry holds a tweak not yet fully consumed
  val keyid         = UInt(KeyIDBits.W)                   // key ID from the top bits of the request address
  val len           = UInt(MemcedgeOut.bundle.lenBits.W)  // AXI4 burst length field of the request
  val tweak_encrpty = UInt(128.W)                         // encrypted tweak
  val sel_counter   = Bool()                              // which tweak pair the next R beat uses
}
// One entry of TweakTable's per-ID mode table: the dec_mode flag written through io.w_mode and
// read back through io.r_mode.
class TweakTableModeEntry extends Bundle {
  val dec_mode = Bool()
}
// AR-channel tweak table, indexed by AXI4 ID.
// io.write stores an encrypted tweak (with key ID and burst length) when its read request is
// issued. io.req/io.resp look the entry up again when R beats with that ID return. An entry is
// cleared after its last beat: one read for len == 0, otherwise two reads, the first of which
// sets sel_counter. A separate per-ID mode table stores dec_mode (io.w_mode / io.r_mode).
// Write-port updates are coded before read-side updates: when both touch the same entry in one
// cycle, the read-side v_flag/sel_counter update wins (Chisel last-connect).
// NOTE(review): ids are MemcedgeOut.bundle.idBits wide but the table has only
// 2^(idBits-1) + 1 rows; presumably only ids in that range are issued — confirm.
class TweakTable(implicit p: Parameters) extends MemEncryptModule {
  val io = IO(new Bundle {
    // Write to tweak table
    val write = Flipped(DecoupledIO(new Bundle {
      val id      = UInt(MemcedgeOut.bundle.idBits.W)
      val len     = UInt(MemcedgeOut.bundle.lenBits.W)
      val addr    = UInt(PAddrBits.W)
      val tweak_encrpty = UInt(128.W)
    }))
    // Read from the tweak table with the ID of channel R in AXI4 as the index
    val req = Flipped(DecoupledIO(new Bundle {
      val read_id = UInt(MemcedgeOut.bundle.idBits.W)
    }))
    // Tweak table read response
    val resp = DecoupledIO(new Bundle {
      val read_tweak       = UInt(128.W)
      val read_keyid       = UInt(KeyIDBits.W)
      val read_sel_counter = Bool()
    })
    val w_mode = Flipped(DecoupledIO(new Bundle {
      val id       = UInt(MemcedgeOut.bundle.idBits.W)
      val dec_mode = Input(Bool())
    }))
    val r_mode = new Bundle {
      val id = Input(UInt(MemcedgeOut.bundle.idBits.W))
      val dec_mode = Output(Bool())
    }
  })

  val tableEntries = (1 << (MemcedgeOut.bundle.idBits - 1)) + 1

  // Reset: every entry invalid; mode defaults to "no decryption".
  val init_tweak_entry = Wire(new TweakTableEntry())
  init_tweak_entry := DontCare
  init_tweak_entry.v_flag := false.B
  val init_mode_entry = Wire(new TweakTableModeEntry)
  init_mode_entry.dec_mode := false.B
  val tweak_table = RegInit(VecInit(Seq.fill(tableEntries)(init_tweak_entry)))
  val tweak_mode_table = RegInit(VecInit(Seq.fill(tableEntries)(init_mode_entry)))

  // ---- Write port (always ready) ----
  io.write.ready := true.B
  when(io.write.valid) {
    val entry = tweak_table(io.write.bits.id)
    entry.tweak_encrpty := io.write.bits.tweak_encrpty
    entry.keyid         := io.write.bits.addr(PAddrBits-1, PAddrBits-KeyIDBits)
    entry.len           := io.write.bits.len
    entry.v_flag        := true.B
    // Two-beat bursts (len == 1) start from the lower tweak pair; otherwise addr(5) picks the pair.
    entry.sel_counter   := io.write.bits.len =/= 1.U && io.write.bits.addr(5)
  }

  // ---- Mode table ----
  io.w_mode.ready := true.B
  when(io.w_mode.valid) {
    tweak_mode_table(io.w_mode.bits.id).dec_mode := io.w_mode.bits.dec_mode
  }
  val read_mode_entry = tweak_mode_table(io.r_mode.id)
  io.r_mode.dec_mode := read_mode_entry.dec_mode

  // ---- Read port (one registered response) ----
  val reg_read_valid    = RegInit(false.B)
  val reg_tweak_encrpty = RegInit(0.U(128.W))
  val reg_keyid         = RegInit(0.U(KeyIDBits.W))
  val reg_sel_counter   = RegInit(false.B)

  val read_entry = tweak_table(io.req.bits.read_id)

  // Accept a lookup only for a valid entry, and only if the response register is free or draining.
  io.req.ready := read_entry.v_flag && (!reg_read_valid || io.resp.ready)

  when(io.req.fire) {
    reg_read_valid    := true.B
    reg_tweak_encrpty := read_entry.tweak_encrpty
    reg_keyid         := read_entry.keyid
    reg_sel_counter   := read_entry.sel_counter
    // Single-beat burst, or second beat of a two-beat burst: retire the entry.
    // Otherwise advance to the second tweak pair.
    val finalBeat = read_entry.len === 0.U || read_entry.sel_counter
    when(finalBeat) {
      read_entry.v_flag := false.B
    }.otherwise {
      read_entry.sel_counter := true.B
    }
  }.elsewhen(io.resp.fire) {
    reg_read_valid := false.B
  }

  io.resp.valid                 := reg_read_valid
  io.resp.bits.read_tweak       := reg_tweak_encrpty
  io.resp.bits.read_keyid       := reg_keyid
  io.resp.bits.read_sel_counter := reg_sel_counter
}
657
658
659
// AXI4 routing utilities.
// Write-channel router (AW + W). A burst is encrypted iff io.enc_mode && io.memenc_enable, sampled
// when its AW is accepted. io.enc_mode comes from the parent, presumably looked up from
// io.enc_keyid — confirm.
//   encrypted burst   -> io.out1 (forwarded directly)
//   plaintext burst   -> io.out0, via queues whose depths scale with MemencPipes
// One burst at a time: a new AW is accepted only after the previous burst's last W beat.
class WriteChanelRoute(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in = new Bundle {
      val aw = Flipped(Irrevocable(new AXI4BundleAW(MemcedgeIn.bundle)))
      val w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeIn.bundle)))
    }
    // Unencrypted channel
    val out0 = new Bundle {
      val aw = Irrevocable(new AXI4BundleAW(MemcedgeIn.bundle))
      val w = Irrevocable(new AXI4BundleW(MemcedgeIn.bundle))
    }
    // Encrypted channel
    val out1 = new Bundle {
      val aw = Irrevocable(new AXI4BundleAW(MemcedgeIn.bundle))
      val w = Irrevocable(new AXI4BundleW(MemcedgeIn.bundle))
    }
    val enc_keyid = Output(UInt(KeyIDBits.W))
    val enc_mode = Input(Bool())
    val memenc_enable = Input(Bool())
  })
  io.enc_keyid := io.in.aw.bits.addr(PAddrBits-1, PAddrBits-KeyIDBits)

  val reg_idle = RegInit(true.B)       // no W burst in progress
  val reg_enc_mode = RegInit(false.B)  // mode latched for the burst in progress

  val awAccepted = io.in.aw.fire
  val wBurstDone = io.in.w.fire && io.in.w.bits.last
  val requestedEncMode = io.enc_mode && io.memenc_enable

  when(awAccepted) {
    reg_enc_mode := requestedEncMode
  }
  // A burst whose last W beat is accepted in the same cycle as its AW leaves the router idle.
  reg_idle := Mux(wBurstDone, true.B, Mux(awAccepted, false.B, reg_idle))

  // This cycle's mode: the fresh value while the AW is accepted, otherwise the latched one.
  val used_enc_mode = Mux(awAccepted, requestedEncMode, reg_enc_mode)
  // W beats flow only while a burst is open (its AW accepted now or earlier).
  val wOpen = awAccepted || !reg_idle

  // Single-entry flow queue so io.in.aw.ready does not depend on io.out*.aw.ready.
  val aw_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in.aw.bits), 1, flow = true))
  aw_queue.io.enq.valid := reg_idle && io.in.aw.valid
  aw_queue.io.enq.bits := io.in.aw.bits
  io.in.aw.ready := reg_idle && aw_queue.io.enq.ready

  val unencrypt_aw_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in.aw.bits), MemencPipes+1, flow = true))
  val unencrypt_w_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in.w.bits), (MemencPipes+1)*2, flow = true))

  // Back-pressure follows whichever destination the current mode selects.
  aw_queue.io.deq.ready := Mux(used_enc_mode, io.out1.aw.ready, unencrypt_aw_queue.io.enq.ready)
  io.in.w.ready := wOpen && Mux(used_enc_mode, io.out1.w.ready, unencrypt_w_queue.io.enq.ready)

  // Plaintext path -> out0
  unencrypt_aw_queue.io.enq.valid := aw_queue.io.deq.valid && !used_enc_mode
  unencrypt_aw_queue.io.enq.bits := aw_queue.io.deq.bits
  unencrypt_w_queue.io.enq.valid := io.in.w.valid && wOpen && !used_enc_mode
  unencrypt_w_queue.io.enq.bits := io.in.w.bits

  io.out0.aw.valid := unencrypt_aw_queue.io.deq.valid
  io.out0.aw.bits := unencrypt_aw_queue.io.deq.bits
  unencrypt_aw_queue.io.deq.ready := io.out0.aw.ready
  io.out0.w.valid := unencrypt_w_queue.io.deq.valid
  io.out0.w.bits := unencrypt_w_queue.io.deq.bits
  unencrypt_w_queue.io.deq.ready := io.out0.w.ready

  // Encrypted path -> out1
  io.out1.aw.valid := aw_queue.io.deq.valid && used_enc_mode
  io.out1.aw.bits := aw_queue.io.deq.bits
  io.out1.w.valid := io.in.w.valid && wOpen && used_enc_mode
  io.out1.w.bits := io.in.w.bits
}
734
// Merges the unencrypted (in0) and encrypted (in1) AXI4 write streams into one AW/W output.
// AW arbitration alternates preference: after an in1 AW, in0 is preferred; otherwise in1 is.
// Once an AW is accepted, W beats are taken only from that input until its last beat is sent.
class WriteChanelArbiter(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    // Unencrypted channel
    val in0 = new Bundle {
      val aw = Flipped(Irrevocable(new AXI4BundleAW(MemcedgeOut.bundle)))
      val w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeOut.bundle)))
    }
    // Encrypted channel
    val in1 = new Bundle {
      val aw = Flipped(Irrevocable(new AXI4BundleAW(MemcedgeOut.bundle)))
      val w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeOut.bundle)))
    }
    val out = new Bundle {
      val aw = Irrevocable(new AXI4BundleAW(MemcedgeOut.bundle))
      val w = Irrevocable(new AXI4BundleW(MemcedgeOut.bundle))
    }
  })

  val validMask = RegInit(false.B) // true: the previous AW came from in1 (encrypted)
  val aw_choice = Wire(Bool())     // true: in1 wins the current AW arbitration
  val w_choice = RegInit(false.B)  // source of the W burst in progress (true: in1)
  val reg_idle = RegInit(true.B)   // no W burst in progress
  // Single-entry flow queue so the inputs' aw.ready does not depend on io.out.aw.ready.
  val aw_queue = Module(new IrrevocableQueue(chiselTypeOf(io.in0.aw.bits), 1, flow = true))

  val awTaken = aw_queue.io.enq.fire
  val wBurstDone = io.out.w.fire && io.out.w.bits.last

  // Remember which input supplied the last accepted AW.
  when(io.in0.aw.fire || io.in1.aw.fire) {
    validMask := io.in1.aw.fire
  }

  // After in1: take in1 only if in0 has nothing. After in0: take in1 whenever it is valid.
  aw_choice := Mux(validMask, !io.in0.aw.valid, io.in1.aw.valid)

  when(awTaken) {
    w_choice := aw_choice
  }
  // A burst whose last W beat leaves in the same cycle as its AW is accepted leaves us idle.
  reg_idle := Mux(wBurstDone, true.B, Mux(awTaken, false.B, reg_idle))

  val used_w_choice = Mux(awTaken, aw_choice, w_choice)
  val wOpen = awTaken || !reg_idle

  aw_queue.io.enq.valid := reg_idle && (io.in0.aw.valid || io.in1.aw.valid)
  aw_queue.io.enq.bits := Mux(aw_choice, io.in1.aw.bits, io.in0.aw.bits)
  io.in0.aw.ready := reg_idle && aw_queue.io.enq.ready && !aw_choice
  io.in1.aw.ready := reg_idle && aw_queue.io.enq.ready && aw_choice

  io.out.aw.valid := aw_queue.io.deq.valid
  io.out.aw.bits := aw_queue.io.deq.bits
  aw_queue.io.deq.ready := io.out.aw.ready

  io.in0.w.ready := wOpen && io.out.w.ready && !used_w_choice
  io.in1.w.ready := wOpen && io.out.w.ready && used_w_choice
  io.out.w.valid := wOpen && Mux(used_w_choice, io.in1.w.valid, io.in0.w.valid)
  io.out.w.bits  := Mux(used_w_choice, io.in1.w.bits, io.in0.w.bits)
}
799
// Splits the AXI4 R channel by the per-beat decryption flag io.dec_mode: beats with dec_mode set
// go to out_r1 (encrypted channel), all others to out_r0 (unencrypted channel).
// io.dec_rid exposes the beat's ID; the parent presumably uses it to look up dec_mode — confirm.
class RdataChanelRoute(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)))
    // Unencrypted channel
    val out_r0 = Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))
    // Encrypted channel
    val out_r1 = Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))
    val dec_rid = Output(UInt(MemcedgeOut.bundle.idBits.W))
    val dec_mode = Input(Bool())
  })
  io.dec_rid := io.in_r.bits.id

  val r_sel = io.dec_mode

  // Both outputs see the payload; only the selected one sees valid.
  for ((port, selected) <- Seq(io.out_r0 -> !r_sel, io.out_r1 -> r_sel)) {
    port.bits := io.in_r.bits
    port.valid := io.in_r.valid && selected
  }
  io.in_r.ready := Mux(r_sel, io.out_r1.ready, io.out_r0.ready)
}