xref: /btstack/3rd-party/micro-ecc/scripts/mult_avr.py (revision cd5f23a3250874824c01a2b3326a9522fea3f99f)
1#!/usr/bin/env python3
2
3import sys
4
#### parse and validate the command line
# Diagnostics go to stderr because stdout carries the generated assembly.
if len(sys.argv) < 2:
    print("Provide the integer size in bytes", file=sys.stderr)
    sys.exit(1)

try:
    size = int(sys.argv[1])
except ValueError:
    print("Size must be an integer number of bytes, got %r" % sys.argv[1],
          file=sys.stderr)
    sys.exit(1)

# Lower bound: zero or negative sizes describe no multiplication at all and
# would otherwise be turned into a bogus 10-byte block or negative immediates.
# Upper bound: between rows the generator emits "sbiw r30, 2 * prev_size + 10"
# with prev_size reaching size - 10.  adiw/sbiw only accept immediates in
# 0..63, so 2 * size - 10 <= 63, i.e. size <= 36, is the largest size whose
# output can be assembled.
if not 1 <= size <= 36:
    print("Size must be between 1 and 36 bytes, got %d" % size,
          file=sys.stderr)
    sys.exit(1)

# The operands are handled as one initial block of init_size bytes (1..10)
# followed by full_rows rows of exactly 10 bytes each.
full_rows = size // 10
init_size = size % 10

# A size that is a multiple of 10 has no remainder block; use one of the
# full rows as the initial block instead.
if init_size == 0:
    full_rows = full_rows - 1
    init_size = 10
17
def rx(i):
    """Return the register number used for slot *i* of the left operand.

    The left (x) operand bytes are kept in consecutive registers starting
    at r2, so slot 0 maps to r2, slot 1 to r3, and so on.
    """
    first_x_register = 2
    return i + first_x_register
20
def ry(i):
    """Return the register number used for slot *i* of the right operand.

    The right (y) operand bytes are kept in consecutive registers starting
    at r12, so slot 0 maps to r12, slot 1 to r13, and so on.
    """
    first_y_register = 12
    return i + first_y_register
23
def emit(line, *args):
    """Print one assembly instruction as a double-quoted string.

    *line* is a %-style format string and *args* supplies its values.  The
    printed text is the formatted instruction wrapped in double quotes and
    terminated with a literal backslash-n backslash-t sequence.
    """
    quoted = "".join(('"', line, r' \n\t"'))
    print(quoted % args)
27
#### set up registers
# Advance z and y by the same amount so that the first block multiplies the
# first init_size bytes of x with the last init_size bytes of y.
high_offset = size - init_size
emit("adiw r30, %s", high_offset) # move z
emit("adiw r28, %s", high_offset) # move y

# Fill the operand registers: x bytes first, then y bytes.
for x_reg in map(rx, range(init_size)):
    emit("ld r%s, x+", x_reg)
for y_reg in map(ry, range(init_size)):
    emit("ld r%s, y+", y_reg)

# r25 stays zero for the whole routine; "adc rN, r25" adds only the carry.
emit("ldi r25, 0")
print("")
if init_size != 1:
    #### first two multiplications of initial block
    # Result byte 0 is the low byte of x0*y0; its high byte seeds r22.
    for asm in ("ldi r23, 0",
                "mul r2, r12",
                "st z+, r0",
                "mov r22, r1"):
        emit(asm)
    print("")
    # Result byte 1 collects x0*y1 and x1*y0 on top of r22.
    for asm in ("ldi r24, 0",
                "mul r2, r13",
                "add r22, r0",
                "adc r23, r1",
                "mul r3, r12",
                "add r22, r0",
                "adc r23, r1",
                "adc r24, r25",
                "st z+, r22"):
        emit(asm)
    print("")

    #### rest of initial block, with moving accumulator registers
    # lo receives the result byte being completed, mid and hi take the
    # carries.  After each store the roles rotate so that no register
    # contents have to be moved.
    lo, mid, hi = 23, 24, 22
    last = init_size - 1

    # Columns that still grow: column `col` sums x[k] * y[col - k].
    for col in range(2, init_size):
        emit("ldi r%s, 0", hi)
        for k in range(col + 1):
            emit("mul r%s, r%s", rx(k), ry(col - k))
            emit("add r%s, r0", lo)
            emit("adc r%s, r1", mid)
            emit("adc r%s, r25", hi)
        emit("st z+, r%s", lo)
        print("")
        lo, mid, hi = mid, hi, lo

    # Columns that shrink again: one product fewer per step.
    for start in range(1, last):
        emit("ldi r%s, 0", hi)
        for k in range(init_size - start):
            emit("mul r%s, r%s", rx(start + k), ry(last - k))
            emit("add r%s, r0", lo)
            emit("adc r%s, r1", mid)
            emit("adc r%s, r25", hi)
        emit("st z+, r%s", lo)
        print("")
        lo, mid, hi = mid, hi, lo

    # The last product produces the two topmost bytes; no third
    # accumulator register is involved any more.
    emit("mul r%s, r%s", rx(last), ry(last))
    emit("add r%s, r0", lo)
    emit("adc r%s, r1", mid)
    emit("st z+, r%s", lo)
    emit("st z+, r%s", mid)
else:
    # One byte times one byte: the product in r1:r0 is the whole result.
    for asm in ("mul r2, r12",
                "st z+, r0",
                "st z+, r1"):
        emit(asm)
print("")
89
#### reset y and z pointers
# This section only prepares the first full 10-byte row.  When there is no
# such row (size <= 10) the product is already complete, and emitting the
# rewinds and loads anyway would read 10 bytes below the y operand and past
# the end of the x operand, so it is skipped entirely.
if full_rows > 0:
    # The initial block stored 2 * init_size result bytes and read init_size
    # bytes of y; stepping back 10 bytes further places both pointers at the
    # 10 bytes directly below the initial block.
    emit("sbiw r30, %s", 2 * init_size + 10)
    emit("sbiw r28, %s", init_size + 10)

    #### load y registers
    for i in range(10):
        emit("ld r%s, y+", ry(i))

    #### load additional x registers
    # x0 .. x(init_size-1) are still in their registers; top up to 10 bytes.
    for i in range(init_size, 10):
        emit("ld r%s, x+", rx(i))
    print("")
102
# prev_size is the width in bytes of the block multiplied so far; it starts
# at the initial block and grows by 10 with every full row.  The previous
# block left 2 * prev_size result bytes in memory that this row adds onto.
prev_size = init_size
for row in range(full_rows):
    #### do x = 0-9, y = 0-9 multiplications
    # Result byte 0 of this row: low byte of r2*r12, stored directly; the
    # high byte seeds accumulator r22.
    emit("ldi r23, 0")
    emit("mul r2, r12")
    emit("st z+, r0")
    emit("mov r22, r1")
    print("")
    # Result byte 1: r2*r13 + r3*r12 on top of r22, carries into r23/r24.
    emit("ldi r24, 0")
    emit("mul r2, r13")
    emit("add r22, r0")
    emit("adc r23, r1")
    emit("mul r3, r12")
    emit("add r22, r0")
    emit("adc r23, r1")
    emit("adc r24, r25")
    emit("st z+, r22")
    print("")

    # acc[0] receives the byte being completed, acc[1] and acc[2] take the
    # carries; the list is rotated after every stored byte.
    acc = [23, 24, 22]
    for r in range(2, 10):
        emit("ldi r%s, 0", acc[2])
        # column r sums x[i] * y[r - i] for i = 0..r
        for i in range(0, r+1):
            emit("mul r%s, r%s", rx(i), ry(r - i))
            emit("add r%s, r0", acc[0])
            emit("adc r%s, r1", acc[1])
            emit("adc r%s, r25", acc[2]) # r25 is zero: adds the carry only
        emit("st z+, r%s", acc[0])
        print("")
        acc = acc[1:] + acc[:1]

    #### now we need to start shifting x and loading from z
    # x_regs is a sliding 10-register window over x: each step rotates the
    # list and overwrites the register that drops out of the window (now at
    # index 9) with the next x byte.  The y registers stay fixed.
    x_regs = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    for r in range(0, prev_size):
        x_regs = x_regs[1:] + x_regs[:1]
        emit("ld r%s, x+", x_regs[9]) # load next byte of left
        emit("ldi r%s, 0", acc[2])
        for i in range(0, 10):
            emit("mul r%s, r%s", x_regs[i], ry(9 - i))
            emit("add r%s, r0", acc[0])
            emit("adc r%s, r1", acc[1])
            emit("adc r%s, r25", acc[2])
        emit("ld r0, z") # load stored value from initial block, and add to accumulator (note z does not increment)
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r25", acc[1]) # only the carry propagates upwards
        emit("adc r%s, r25", acc[2])
        emit("st z+, r%s", acc[0]) # store next byte (z increments)
        print("")
        acc = acc[1:] + acc[:1]

    # done shifting x, start shifting y
    # Same sliding-window scheme for y; x_regs keeps the order it ended up
    # with above.
    y_regs = [12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
    for r in range(0, prev_size):
        y_regs = y_regs[1:] + y_regs[:1]
        emit("ld r%s, y+", y_regs[9]) # load next byte of right
        emit("ldi r%s, 0", acc[2])
        for i in range(0, 10):
            emit("mul r%s, r%s", x_regs[i], y_regs[9 -i])
            emit("add r%s, r0", acc[0])
            emit("adc r%s, r1", acc[1])
            emit("adc r%s, r25", acc[2])
        emit("ld r0, z") # load stored value from initial block, and add to accumulator (note z does not increment)
        emit("add r%s, r0", acc[0])
        emit("adc r%s, r25", acc[1]) # only the carry propagates upwards
        emit("adc r%s, r25", acc[2])
        emit("st z+, r%s", acc[0]) # store next byte (z increments)
        print("")
        acc = acc[1:] + acc[:1]

    # done both shifts, do remaining corner
    # These columns shrink by one product per step and are stored without
    # reading z first.
    for r in range(1, 9):
        emit("ldi r%s, 0", acc[2])
        for i in range(0, 10-r):
            emit("mul r%s, r%s", x_regs[r+i], y_regs[9 - i])
            emit("add r%s, r0", acc[0])
            emit("adc r%s, r1", acc[1])
            emit("adc r%s, r25", acc[2])
        emit("st z+, r%s", acc[0])
        print("")
        acc = acc[1:] + acc[:1]
    # The last product yields the two topmost result bytes of this row.
    emit("mul r%s, r%s", x_regs[9], y_regs[9])
    emit("add r%s, r0", acc[0])
    emit("adc r%s, r1", acc[1])
    emit("st z+, r%s", acc[0])
    emit("st z+, r%s", acc[1])
    print("")

    prev_size = prev_size + 10
    if row < full_rows - 1:
        #### reset x, y and z pointers
        # This row stored 2 * prev_size bytes through z and read prev_size
        # bytes through each of x and y (prev_size already includes this
        # row).  x goes back to where this row's reads began; y and z go a
        # further 10 bytes down for the next row.
        emit("sbiw r30, %s", 2 * prev_size + 10)
        emit("sbiw r28, %s", prev_size + 10)
        emit("sbiw r26, %s", prev_size)

        #### load x and y registers
        for i in range(10):
            emit("ld r%s, x+", rx(i))
            emit("ld r%s, y+", ry(i))
        print("")

# Clear r1, which every mul above overwrote with a product high byte.
# NOTE(review): presumably required because the surrounding compiled code
# expects r1 to be zero -- confirm against the calling convention in use.
emit("eor r1, r1")
204