# xref: /btstack/3rd-party/micro-ecc/scripts/square_avr.py (revision cd5f23a3250874824c01a2b3326a9522fea3f99f)
#!/usr/bin/env python3
r"""Emit AVR assembly that squares a SIZE-byte integer.

Usage: square_avr.py SIZE

Every instruction is printed to stdout as a double-quoted string literal
ending in a ``\n\t`` escape (see ``emit``), presumably for pasting into an
avr-gcc inline-asm statement.  Diagnostics go to stderr so they cannot end up
in a redirected output file.
"""

import sys

if len(sys.argv) < 2:
    sys.exit("Provide the integer size in bytes")

try:
    size = int(sys.argv[1])
except ValueError:
    sys.exit("Integer size must be a whole number of bytes, got %r" % (sys.argv[1],))

if size > 40:
    # The register allocation below only covers one 20-byte window (r2-r21)
    # plus one extra block of at most 20 bytes.
    sys.exit("This script doesn't work with integer size %s due to laziness" % (size))

if size < 4:
    # Columns 0-2 are emitted unconditionally and the closing corner always
    # emits the top columns; below 4 bytes those two ranges overlap and the
    # routine would store more than 2*size result bytes.
    sys.exit("This script needs an integer size of at least 4 bytes, got %s" % (size))

# Bytes beyond the 20-byte register window; 0 when the whole integer fits.
init_size = max(size - 20, 0)


def rg(i):
    """Return the register number holding byte ``i`` of the main window.

    The main pass loads its operand bytes into r2, r3, ... in order.
    """
    return i + 2


def lo(i):
    """Return the register number for the i-th low-end byte of the initial block.

    The initial block loads a[0], a[1], ... (read through x) into r2, r3, ...
    """
    return i + 2


def hi(i):
    """Return the register number for the i-th high-end byte of the initial block.

    The high-end bytes (read through y) are loaded into r12, r13, ...
    """
    return i + 12


def emit(line, *args):
    r"""Print one instruction as a double-quoted string literal.

    ``line`` is a %-style template filled in with ``args``.  The result is
    wrapped in double quotes and terminated with a literal ``\n\t`` escape,
    e.g. ``emit("ldi %s, 0", "r25")`` prints ``"ldi r25, 0 \n\t"``.
    """
    s = '"' + line + r' \n\t"'
    print(s % args)


#### set up registers
# r25 is kept at zero for the whole routine so that "adc rN, r25" adds only
# the carry flag.  mul leaves its 16-bit product in r1:r0 (low byte in r0),
# which is why every product is folded in with "add ..., r0" / "adc ..., r1".
zero = "r25"
emit("ldi %s, 0", zero) # zero register

#### initial block (only when size > 20)
# The main pass further down keeps a window of 20 operand bytes in r2-r21, so
# a cross product a[i]*a[j] with j - i >= 20 never falls inside that window.
# This block sums exactly those products (each pair once, not yet doubled)
# and stores the 2*init_size-byte sum at z + 20; the main pass later adds it
# back in, byte by byte, before doubling.
#
# x (r26:r27) is only read ("ld ..., x+"), z (r30:r31) is only written
# ("st z+, ..."), and y (r28:r29) is a second read pointer starting at x + 20.
# Presumably x points at the operand and z at the 2*size-byte result - TODO
# confirm against the calling C code.
if init_size > 0:
    emit("movw r28, r26") # y = x
    # number of low-end / high-end bytes held in registers at any one time
    h = (init_size + 1)//2

    # low end a[0..h-1] -> r2.., high end a[20..20+h-1] -> r12..
    for i in range(h):
        emit("ld r%s, x+", lo(i))
    emit("adiw r28, %s", size - init_size) # move y to other end (size - init_size == 20 here)
    for i in range(h):
        emit("ld r%s, y+", hi(i))

    emit("adiw r30, %s", size - init_size) # move z

    if init_size == 1:
        # a single product a[0]*a[20]; nothing to accumulate
        emit("mul r%s, r%s", lo(0), hi(0))
        emit("st z+, r0")
        emit("st z+, r1")
    else:
        # Column r of this block sums a[i]*a[20 + r - i] over i <= r - i.
        # Three accumulator registers rotate: once a column's low byte is
        # stored, its middle byte becomes the next column's low byte.

        #### first one
        print("")
        emit("ldi r23, 0")
        emit("mul r%s, r%s", lo(0), hi(0))
        emit("st z+, r0")
        emit("mov r22, r1")
        print("")

        #### rest of initial block, with moving accumulator registers
        acc = [22, 23, 24]
        for r in range(1, h):
            emit("ldi r%s, 0", acc[2])
            for i in range(0, (r+2)//2):
                emit("mul r%s, r%s", lo(i), hi(r - i))
                emit("add r%s, r0", acc[0])
                emit("adc r%s, r1", acc[1])
                emit("adc r%s, %s", acc[2], zero)
            emit("st z+, r%s", acc[0])
            print("")
            acc = acc[1:] + acc[:1]

        # Register windows over the low and high ends.  They must be real
        # lists: the rotations below concatenate slices, which Python 3
        # range objects do not support.
        lo_r = list(range(2, 2 + h))
        hi_r = list(range(12, 12 + h))

        # now we need to start loading more from the high end: the register
        # whose byte is no longer needed is reloaded and becomes hi_r[h-1]
        for r in range(h, init_size):
            hi_r = hi_r[1:] + hi_r[:1]
            emit("ld r%s, y+", hi_r[h-1])

            emit("ldi r%s, 0", acc[2])
            for i in range(0, (r+2)//2):
                emit("mul r%s, r%s", lo(i), hi_r[h - 1 - i])
                emit("add r%s, r0", acc[0])
                emit("adc r%s, r1", acc[1])
                emit("adc r%s, %s", acc[2], zero)
            emit("st z+, r%s", acc[0])
            print("")
            acc = acc[1:] + acc[:1]

        # loaded all of the high end bytes; now need to start loading the rest
        # of the low end (column init_size - 1 + r, lo_r[0] holds a[r])
        for r in range(1, init_size-h):
            lo_r = lo_r[1:] + lo_r[:1]
            emit("ld r%s, x+", lo_r[h-1])

            emit("ldi r%s, 0", acc[2])
            for i in range(0, (init_size+1 - r)//2):
                emit("mul r%s, r%s", lo_r[i], hi_r[h - 1 - i])
                emit("add r%s, r0", acc[0])
                emit("adc r%s, r1", acc[1])
                emit("adc r%s, %s", acc[2], zero)
            emit("st z+, r%s", acc[0])
            print("")
            acc = acc[1:] + acc[:1]

        # load the last low-end byte a[init_size - 1]
        lo_r = lo_r[1:] + lo_r[:1]
        emit("ld r%s, x+", lo_r[h-1])

        # now we have loaded everything, and we just need to finish the last corner
        for r in range(init_size-h, init_size-1):
            emit("ldi r%s, 0", acc[2])
            for i in range(0, (init_size+1 - r)//2):
                emit("mul r%s, r%s", lo_r[i], hi_r[h - 1 - i])
                emit("add r%s, r0", acc[0])
                emit("adc r%s, r1", acc[1])
                emit("adc r%s, %s", acc[2], zero)
            emit("st z+, r%s", acc[0])
            print("")
            acc = acc[1:] + acc[:1]
            lo_r = lo_r[1:] + lo_r[:1] # make the indexing easy

        # top column: a[init_size-1] * a[20 + init_size - 1], plus both result bytes
        emit("mul r%s, r%s", lo_r[0], hi_r[h - 1])
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("st z+, r%s", acc[0])
        emit("st z+, r%s", acc[1])
    print("")
    emit("sbiw r26, %s", init_size) # reset x (init_size bytes were read through x)
    emit("sbiw r30, %s", size + init_size) # reset z (20 skipped + 2*init_size stored)

# TODO you could do more rows of size 20 here if your integers are larger than 40 bytes

#### main pass
# s operand bytes (all of them when size <= 20, otherwise the first 20) are
# loaded into r2..r(s+1).  Result column c is
#     2 * sum(a[j] * a[c - j] for j < c - j) + (a[c // 2] ** 2 if c is even)
# plus the carry out of column c - 1, built in three accumulator registers.
s = size - init_size

for i in range(s):
    emit("ld r%s, x+", rg(i))

#### first few columns
# NOTE: these hand-written columns 0-2 need s >= 3; together with the closing
# corner emitted at the very end, the generator as a whole needs s >= 4.
print("")
emit("ldi r23, 0")
emit("mul r%s, r%s", rg(0), rg(0))
emit("st z+, r0")
emit("mov r22, r1")
print("")
emit("ldi r24, 0")
emit("mul r%s, r%s", rg(0), rg(1))
emit("add r22, r0")
emit("adc r23, r1")
emit("adc r24, %s", zero)
emit("add r22, r0")
emit("adc r23, r1")
emit("adc r24, %s", zero)
emit("st z+, r22")
print("")
emit("ldi r22, 0")
emit("mul r%s, r%s", rg(0), rg(2))
emit("add r23, r0")
emit("adc r24, r1")
emit("adc r22, %s", zero)
emit("add r23, r0")
emit("adc r24, r1")
emit("adc r22, %s", zero)
emit("mul r%s, r%s", rg(1), rg(1))
emit("add r23, r0")
emit("adc r24, r1")
emit("adc r22, %s", zero)
emit("st z+, r23")
print("")

# acc: [low byte (stored after each column), middle, high] registers of the
# current column; old_acc: the two registers carrying into the next column.
# r28/r29 (y) are free by now and join the rotation.
acc = [23, 24, 22]
old_acc = [28, 29]
for i in range(3, s):
    emit("ldi r%s, 0", old_acc[1])
    tmp = [acc[1], acc[2]]
    acc = [acc[0], old_acc[0], old_acc[1]]
    old_acc = tmp

    # gather non-equal words
    emit("mul r%s, r%s", rg(0), rg(i))
    emit("mov r%s, r0", acc[0])
    emit("mov r%s, r1", acc[1])
    for j in range(1, (i+1)//2):
        emit("mul r%s, r%s", rg(j), rg(i-j))
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("adc r%s, %s", acc[2], zero)
    # multiply by 2
    emit("lsl r%s", acc[0])
    emit("rol r%s", acc[1])
    emit("rol r%s", acc[2])

    # add equal word (only even columns have one)
    if i % 2 == 0:
        emit("mul r%s, r%s", rg(i//2), rg(i//2))
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("adc r%s, %s", acc[2], zero)

    # add old accumulator (carry out of the previous column)
    emit("add r%s, r%s", acc[0], old_acc[0])
    emit("adc r%s, r%s", acc[1], old_acc[1])
    emit("adc r%s, %s", acc[2], zero)

    # store
    emit("st z+, r%s", acc[0])
    print("")

#### rows for the bytes beyond the window (only when size > 20)
# regs maps window position -> register number.  It must be a list: rotating
# it with slice concatenation raises TypeError on a Python 3 range object.
regs = list(range(2, 22))
for i in range(init_size):
    # slide the window up by one byte: the register holding the oldest byte is
    # reloaded with the next operand byte and becomes regs[19]
    regs = regs[1:] + regs[:1]
    emit("ld r%s, x+", regs[19])

    # two result columns per new byte: regs[0] paired with regs[18], then
    # with regs[19]
    for limit in [18, 19]:
        emit("ldi r%s, 0", old_acc[1])
        tmp = [acc[1], acc[2]]
        acc = [acc[0], old_acc[0], old_acc[1]]
        old_acc = tmp

        # gather non-equal words
        emit("mul r%s, r%s", regs[0], regs[limit])
        emit("mov r%s, r0", acc[0])
        emit("mov r%s, r1", acc[1])
        for j in range(1, (limit+1)//2):
            emit("mul r%s, r%s", regs[j], regs[limit-j])
            emit("add r%s, r0", acc[0])
            emit("adc r%s, r1", acc[1])
            emit("adc r%s, %s", acc[2], zero)

        emit("ld r0, z") # load stored value from initial block, and add to accumulator (note z does not increment)
        emit("add r%s, r0", acc[0])
        emit("adc r%s, %s", acc[1], zero)
        emit("adc r%s, %s", acc[2], zero)

        # multiply by 2 (this also doubles the initial-block contribution,
        # which holds each long-distance pair only once)
        emit("lsl r%s", acc[0])
        emit("rol r%s", acc[1])
        emit("rol r%s", acc[2])

        # add equal word: the limit-18 column is even, its square is regs[9]**2
        if limit == 18:
            emit("mul r%s, r%s", regs[9], regs[9])
            emit("add r%s, r0", acc[0])
            emit("adc r%s, r1", acc[1])
            emit("adc r%s, %s", acc[2], zero)

        # add old accumulator (carry out of the previous column)
        emit("add r%s, r%s", acc[0], old_acc[0])
        emit("adc r%s, r%s", acc[1], old_acc[1])
        emit("adc r%s, %s", acc[2], zero)

        # store
        emit("st z+, r%s", acc[0])
        print("")

#### upper columns
# Window-relative column s - 1 + i for i = 1 .. s-4: pair regs[i] with the
# last window byte regs[s-1] and work inwards.  The closing corner (the last
# three columns plus the final carry byte) is emitted separately afterwards.
for i in range(1, s-3):
    emit("ldi r%s, 0", old_acc[1])
    tmp = [acc[1], acc[2]]
    acc = [acc[0], old_acc[0], old_acc[1]]
    old_acc = tmp

    # gather non-equal words
    emit("mul r%s, r%s", regs[i], regs[s - 1])
    emit("mov r%s, r0", acc[0])
    emit("mov r%s, r1", acc[1])
    for j in range(1, (s-i)//2):
        emit("mul r%s, r%s", regs[i+j], regs[s - 1 - j])
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("adc r%s, %s", acc[2], zero)
    # multiply by 2
    emit("lsl r%s", acc[0])
    emit("rol r%s", acc[1])
    emit("rol r%s", acc[2])

    # add equal word: present when the column s - 1 + i is even, i.e. s - i odd
    if (s - i) % 2 == 1:
        mid = regs[i + (s - i)//2]
        emit("mul r%s, r%s", mid, mid)
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r1", acc[1])
        emit("adc r%s, %s", acc[2], zero)

    # add old accumulator (carry out of the previous column)
    emit("add r%s, r%s", acc[0], old_acc[0])
    emit("adc r%s, r%s", acc[1], old_acc[1])
    emit("adc r%s, %s", acc[2], zero)

    # store
    emit("st z+, r%s", acc[0])
    print("")

#### closing corner
# The last three columns (window-relative 2s-4, 2s-3, 2s-2) plus the final
# carry byte.  Index from s rather than hard-coding regs[17..19]: those are
# the window's last three bytes only when s == 20, and for any size below 20
# they point past the registers that were actually loaded.
acc = acc[1:] + acc[:1]
emit("ldi r%s, 0", acc[2])
emit("mul r%s, r%s", regs[s - 3], regs[s - 1])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("mul r%s, r%s", regs[s - 2], regs[s - 2])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print("")

acc = acc[1:] + acc[:1]
emit("ldi r%s, 0", acc[2])
emit("mul r%s, r%s", regs[s - 2], regs[s - 1])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print("")

# top column: the last byte squared, then the final carry byte
emit("mul r%s, r%s", regs[s - 1], regs[s - 1])
emit("add r%s, r0", acc[1])
emit("adc r%s, r1", acc[2])
emit("st z+, r%s", acc[1])

emit("st z+, r%s", acc[2])
# mul clobbered r1; presumably this restores avr-gcc's zero register
# (__zero_reg__) for the surrounding C code - TODO confirm
emit("eor r1, r1")
