#!/usr/bin/env python3
r"""Generate an unrolled AVR routine that squares a little-endian integer.

Usage: ``square_avr.py SIZE`` where SIZE is the integer length in bytes
(MIN_SIZE..MAX_SIZE).

Every generated instruction is printed as a quoted string fragment ending in
`` \n\t`` (e.g. ``"mul r2, r3 \n\t"``), the form used inside a GCC inline
``asm`` statement.  The generated code:

* reads SIZE bytes, least significant first, through X (r26:r27, ``ld x+``);
* writes the 2*SIZE-byte square, least significant first, through Z
  (r30:r31, ``st z+``);
* leaves X advanced by SIZE and Z by 2*SIZE;
* uses r0-r25 and Y (r28:r29) as scratch, and clears r1 at the end
  (avr-gcc treats r1 as its always-zero register).

Algorithm: column-wise (product-scanning) squaring.  For output byte k, the
cross products x[a]*x[b] with a < b and a + b == k are summed and doubled.
Then the square term x[k/2]**2 (k even) and the previous column's two carry
bytes are added.

The first WINDOW input bytes live in r2..r21.  For larger sizes the window
slides one byte at a time.  Cross products whose indices differ by WINDOW or
more never share the window, so an initial block computes them first.  It
parks their (undoubled) sum in the output buffer, where the main pass adds it
back in before doubling.
"""

import sys

ZERO = "r25"  # register loaded with 0; the carry-only operand of `adc`
WINDOW = 20  # input bytes held in registers at once (r2..r21)
MIN_SIZE = 3  # the first three output columns are emitted unconditionally
MAX_SIZE = 2 * WINDOW  # only a single initial block is generated (see TODO)


def rg(i):
    """Return the register number holding window byte ``i`` in the main pass."""
    return i + 2


def lo(i):
    """Return the register number holding low-end byte ``i`` in the initial block."""
    return i + 2


def hi(i):
    """Return the register number holding high-end byte ``i`` in the initial block."""
    return i + 12


def emit(line, *args):
    r"""Print ``line % args`` as one quoted inline-asm fragment ending in `` \n\t``."""
    s = '"' + line + r' \n\t"'
    print(s % args)


def _rotate(regs):
    """Return ``regs`` rotated left by one (the oldest register moves to the end)."""
    return regs[1:] + regs[:1]


def _accumulate(acc, times=1):
    """Emit code adding r1:r0 into the 3-byte accumulator ``acc``, ``times`` times.

    ``acc`` lists register numbers, least significant byte first.
    """
    for _ in range(times):
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("adc r%s, %s", acc[2], ZERO)


def _mul_add(a, b, acc, times=1):
    """Emit ``r1:r0 = r<a> * r<b>`` and add the product into ``acc`` ``times`` times."""
    emit("mul r%s, r%s", a, b)
    _accumulate(acc, times)


def _emit_initial_block(size, init_size):
    """Emit the cross products x[i] * x[WINDOW + j] for i <= j < init_size.

    These pairs are at least WINDOW bytes apart, so the main pass never has
    both bytes in registers at the same time.  Their sum (not yet doubled) is
    stored at output offsets WINDOW .. WINDOW + 2*init_size - 1.  X and Z are
    moved back to their starting values at the end.
    """
    emit("movw r28, r26")  # y = x
    # bytes of each end kept in registers at once: lo in r2.., hi in r12..
    h = (init_size + 1) // 2

    for i in range(h):
        emit("ld r%s, x+", lo(i))
    emit("adiw r28, %s", size - init_size)  # move y to other end
    for i in range(h):
        emit("ld r%s, y+", hi(i))

    emit("adiw r30, %s", size - init_size)  # move z

    if init_size == 1:
        emit("mul r%s, r%s", lo(0), hi(0))
        emit("st z+, r0")
        emit("st z+, r1")
    else:
        #### first one
        print()
        emit("ldi r23, 0")
        emit("mul r%s, r%s", lo(0), hi(0))
        emit("st z+, r0")
        emit("mov r22, r1")
        print()

        #### rest of initial block, with moving accumulator registers
        # column r sums lo(i) * hi(r - i) for i <= r - i
        acc = [22, 23, 24]
        for r in range(1, h):
            emit("ldi r%s, 0", acc[2])
            for i in range((r + 2) // 2):
                _mul_add(lo(i), hi(r - i), acc)
            emit("st z+, r%s", acc[0])
            print()
            acc = _rotate(acc)

        # list(), not a bare range: these get rotated with slicing + concatenation
        lo_r = list(range(2, 2 + h))
        hi_r = list(range(12, 12 + h))

        # now we need to start loading more from the high end;
        # after each load hi_r[h - 1 - i] holds x[WINDOW + r - i]
        for r in range(h, init_size):
            hi_r = _rotate(hi_r)
            emit("ld r%s, y+", hi_r[h - 1])

            emit("ldi r%s, 0", acc[2])
            for i in range((r + 2) // 2):
                _mul_add(lo(i), hi_r[h - 1 - i], acc)
            emit("st z+, r%s", acc[0])
            print()
            acc = _rotate(acc)

        # loaded all of the high end bytes; now need to start loading the rest
        # of the low end.  After each load lo_r[k] holds x[r + k].
        for r in range(1, init_size - h):
            lo_r = _rotate(lo_r)
            emit("ld r%s, x+", lo_r[h - 1])

            emit("ldi r%s, 0", acc[2])
            for i in range((init_size + 1 - r) // 2):
                _mul_add(lo_r[i], hi_r[h - 1 - i], acc)
            emit("st z+, r%s", acc[0])
            print()
            acc = _rotate(acc)

        lo_r = _rotate(lo_r)
        emit("ld r%s, x+", lo_r[h - 1])

        # now we have loaded everything, and we just need to finish the last corner
        for r in range(init_size - h, init_size - 1):
            emit("ldi r%s, 0", acc[2])
            for i in range((init_size + 1 - r) // 2):
                _mul_add(lo_r[i], hi_r[h - 1 - i], acc)
            emit("st z+, r%s", acc[0])
            print()
            acc = _rotate(acc)
            lo_r = _rotate(lo_r)  # make the indexing easy

        # top column: its carry-out fits in the two bytes stored here
        emit("mul r%s, r%s", lo_r[0], hi_r[h - 1])
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("st z+, r%s", acc[0])
        emit("st z+, r%s", acc[1])
    print()
    emit("sbiw r26, %s", init_size)  # reset x
    emit("sbiw r30, %s", size + init_size)  # reset z


def _begin_column(acc, old_acc):
    """Start a doubled column and return the new ``(acc, old_acc)`` register lists.

    The previous column's carry bytes ``acc[1:]`` become ``old_acc``.  The new
    accumulator reuses ``acc[0]`` (already stored) plus the two registers of
    the previous ``old_acc``; its top byte is cleared here.
    """
    emit("ldi r%s, 0", old_acc[1])
    return [acc[0], old_acc[0], old_acc[1]], [acc[1], acc[2]]


def _gather_products(pairs, acc):
    """Emit the sum of r<a> * r<b> over ``(a, b)`` in ``pairs`` into ``acc``.

    The first product initialises ``acc[0:2]`` with ``mov``; ``acc[2]`` must
    already be zero (``_begin_column`` clears it).
    """
    (a, b), *rest = pairs
    emit("mul r%s, r%s", a, b)
    emit("mov r%s, r0", acc[0])
    emit("mov r%s, r1", acc[1])
    for a, b in rest:
        _mul_add(a, b, acc)


def _finish_column(acc, old_acc, square=None):
    """Finish a column and store its low byte.

    Doubles ``acc``, adds ``r<square>**2`` when ``square`` is given, adds the
    carry bytes in ``old_acc``, then stores ``acc[0]`` through Z.
    """
    # multiply by 2
    emit("lsl r%s", acc[0])
    emit("rol r%s", acc[1])
    emit("rol r%s", acc[2])

    # add equal word (if any)
    if square is not None:
        _mul_add(square, square, acc)

    # add old accumulator
    emit("add r%s, r%s", acc[0], old_acc[0])
    emit("adc r%s, r%s", acc[1], old_acc[1])
    emit("adc r%s, %s", acc[2], ZERO)

    # store
    emit("st z+, r%s", acc[0])
    print()


def _emit_main_pass(size, init_size):
    """Emit the sliding-window pass that writes all 2*size output bytes in order."""
    s = size - init_size  # bytes held in registers (== WINDOW when init_size > 0)

    for i in range(s):
        emit("ld r%s, x+", rg(i))

    #### first few columns (needs s >= 3, enforced by generate())
    print()
    emit("ldi r23, 0")
    emit("mul r%s, r%s", rg(0), rg(0))
    emit("st z+, r0")
    emit("mov r22, r1")
    print()
    emit("ldi r24, 0")
    _mul_add(rg(0), rg(1), [22, 23, 24], times=2)
    emit("st z+, r22")
    print()
    emit("ldi r22, 0")
    _mul_add(rg(0), rg(2), [23, 24, 22], times=2)
    _mul_add(rg(1), rg(1), [23, 24, 22])
    emit("st z+, r23")
    print()

    # From here on each column is built in a fresh 3-byte accumulator; the
    # previous column's two carry bytes (old_acc) are added after doubling.
    acc = [23, 24, 22]
    old_acc = [28, 29]

    # remaining columns that only involve the first s bytes
    for i in range(3, s):
        acc, old_acc = _begin_column(acc, old_acc)
        _gather_products([(rg(j), rg(i - j)) for j in range((i + 1) // 2)], acc)
        _finish_column(acc, old_acc, rg(i // 2) if i % 2 == 0 else None)

    # list(), not a bare range: it is rotated with slicing + concatenation
    regs = list(range(2, 22))

    # slide the window over the remaining bytes, two output columns per byte
    for _ in range(init_size):
        regs = _rotate(regs)
        emit("ld r%s, x+", regs[WINDOW - 1])

        for limit in (WINDOW - 2, WINDOW - 1):
            acc, old_acc = _begin_column(acc, old_acc)
            _gather_products(
                [(regs[j], regs[limit - j]) for j in range((limit + 1) // 2)], acc
            )

            # load stored value from initial block, and add to accumulator
            # (note z does not increment; the store below overwrites it)
            emit("ld r0, z")
            emit("add r%s, r0", acc[0])
            emit("adc r%s, %s", acc[1], ZERO)
            emit("adc r%s, %s", acc[2], ZERO)

            _finish_column(acc, old_acc, regs[limit // 2] if limit % 2 == 0 else None)

    # upper columns: the window is fully loaded and shrinks from the bottom
    for i in range(1, s - 3):
        acc, old_acc = _begin_column(acc, old_acc)
        width = s - i
        _gather_products(
            [(regs[i + j], regs[s - 1 - j]) for j in range(width // 2)], acc
        )
        _finish_column(acc, old_acc, regs[i + width // 2] if width % 2 else None)

    # last corner, accumulated in place.  Indices are relative to s (not fixed
    # at 17..19) so that sizes below WINDOW use registers that were loaded.
    if s >= 4:  # for s == 3 this column (index 2) was already emitted above
        acc = _rotate(acc)
        emit("ldi r%s, 0", acc[2])
        _mul_add(regs[s - 3], regs[s - 1], acc, times=2)
        _mul_add(regs[s - 2], regs[s - 2], acc)
        emit("st z+, r%s", acc[0])
        print()

    acc = _rotate(acc)
    emit("ldi r%s, 0", acc[2])
    _mul_add(regs[s - 2], regs[s - 1], acc, times=2)
    emit("st z+, r%s", acc[0])
    print()

    emit("mul r%s, r%s", regs[s - 1], regs[s - 1])
    emit("add r%s, r0", acc[1])
    emit("adc r%s, r1", acc[2])
    emit("st z+, r%s", acc[1])

    emit("st z+, r%s", acc[2])
    emit("eor r1, r1")  # restore r1 to zero


def generate(size):
    """Print the squaring routine for a ``size``-byte integer to stdout.

    Raises:
        ValueError: if ``size`` is outside ``MIN_SIZE..MAX_SIZE``.
    """
    if not MIN_SIZE <= size <= MAX_SIZE:
        raise ValueError("unsupported integer size %s" % size)

    # bytes beyond the register window; they need the initial cross-product block
    init_size = max(size - WINDOW, 0)

    #### set up registers
    emit("ldi %s, 0", ZERO)  # zero register

    if init_size > 0:
        _emit_initial_block(size, init_size)

    # TODO you could do more rows of size 20 here if your integers are larger than 40 bytes

    _emit_main_pass(size, init_size)


def main(argv=None):
    """Command-line entry point; return the process exit status.

    ``argv`` defaults to ``sys.argv[1:]``.  Diagnostics go to stderr so they
    never end up mixed into the generated code on stdout.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print("Provide the integer size in bytes", file=sys.stderr)
        return 1

    try:
        size = int(args[0])
    except ValueError:
        print("Integer size must be a whole number of bytes, got %r" % args[0],
              file=sys.stderr)
        return 1

    if size > MAX_SIZE:
        print("This script doesn't work with integer size %s due to laziness" % (size),
              file=sys.stderr)
        return 1
    if size < MIN_SIZE:
        print("Integer size must be at least %s bytes, got %s" % (MIN_SIZE, size),
              file=sys.stderr)
        return 1

    generate(size)
    return 0


if __name__ == "__main__":
    sys.exit(main())