xref: /aosp_15_r20/external/gemmlowp/meta/generators/transform_kernels_common.py (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1# Copyright 2016 The Gemmlowp Authors. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""."""
15
16import common
17
18
19def _DuplicateGeneralRegister(size, emitter, registers, value, min_register):
20  register = registers.QuadRegister(min_register)
21  emitter.EmitVDup(size, register, value)
22  return register
23
24
25def _DuplicateGeneralMemoryRegister(size, emitter, registers, value,
26                                    min_register):
27  register = registers.QuadRegister(min_register)
28  general = registers.GeneralRegister()
29  emitter.EmitLdr(general, value)
30  emitter.EmitVDup(size, register, general)
31  registers.FreeRegister(general)
32  return register
33
34
class MinMaxTransformation(object):
  """Clamps each uint8_t element into the [params.min, params.max] range."""

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested kernel shape is supported.

    Uses `==` rather than `is`: identity comparison with str/int literals only
    works by accident of CPython interning (and is a SyntaxWarning on 3.8+).
    """
    assert in_type == 'uint8_t'
    assert out_type == 'uint8_t'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Broadcast params.min and params.max into quad registers."""
    emitter.EmitNewline()
    emitter.EmitComment('MinMax::Prepare')

    self.min = _DuplicateGeneralRegister(8, emitter, registers,
                                         registers.MapParameter('min',
                                                                'params.min'),
                                         4)
    self.max = _DuplicateGeneralRegister(8, emitter, registers,
                                         registers.MapParameter('max',
                                                                'params.max'),
                                         4)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Generate the MinMax transform inner loop code.

    Args:
      emitter: assembly emitter receiving the instructions.
      registers: register allocator; the quad registers used here are
        allocated and freed within this call.
      input_address: mapped asm operand for the input pointer.
      elements: number of uint8_t elements to process (at most 16, see Check).
      output_address: mapped asm operand for the output pointer.
    """
    emitter.EmitNewline()
    emitter.EmitComment('MinMax::Transform')
    # 16 uint8_t lanes per quad register. Floor division keeps this an int
    # under Python 3 as well (range() rejects floats).
    register_count = (elements + 15) // 16
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(8, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(16))

    # max(x, min) followed by min(x, max) clamps x into [min, max].
    for register in load:
      emitter.EmitVMax('u8', register, register, self.min)

    for register in load:
      emitter.EmitVMin('u8', register, register, self.max)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(8, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)
77
78
class DequantizeTransformation(object):
  """Converts uint8_t elements to float.

  Per element: (x - range_offset) * range_scale + range_min.
  """

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested kernel shape is supported.

    Uses `==` rather than `is`: identity comparison with str/int literals only
    works by accident of CPython interning (and is a SyntaxWarning on 3.8+).
    """
    assert in_type == 'uint8_t'
    assert out_type == 'float'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate quantization offsets to vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('Dequantize::Prepare')

    self.range_min = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_min', 'params.range_min'), 4)
    self.range_offset = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_offset', 'params.range_offset'), 4)
    self.range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_scale', 'params.range_scale'), 4)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Emit the dequantization inner loop.

    Args:
      emitter: assembly emitter receiving the instructions.
      registers: register allocator; the quad registers used here are
        allocated and freed within this call.
      input_address: mapped asm operand for the input pointer.
      elements: number of elements to process (at most 16, see Check).
      output_address: mapped asm operand for the output pointer.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Dequantize::Transform')
    # After widening each quad register holds four 32-bit lanes. Floor
    # division keeps this an int under Python 3 (range() rejects floats).
    register_count = (elements + 3) // 4
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(8, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(32))

    # Widen in place: u8 -> 16-bit lanes, then s16 -> 32-bit lanes. The order
    # matters because later steps overwrite registers earlier steps read.
    if len(load) == 1:
      emitter.EmitVMovl('u8', load[0], load[0])
      emitter.EmitVMovl('s16', load[0], load[0])
    elif len(load) == 2:
      emitter.EmitVMovl('u8', load[0], load[0])
      emitter.EmitVMovl2('s16', load[0], load[1], load[0])
    elif len(load) == 3:
      emitter.EmitVMovl2('u8', load[0], load[1], load[0])
      # Consume load[1] before the next VMovl2 overwrites it.
      emitter.EmitVMovl('s16', load[2], load[1])
      emitter.EmitVMovl2('s16', load[0], load[1], load[0])
    elif len(load) == 4:
      emitter.EmitVMovl2('u8', load[0], load[1], load[0])
      # Consume load[1] before the next VMovl2 overwrites it.
      emitter.EmitVMovl2('s16', load[2], load[3], load[1])
      emitter.EmitVMovl2('s16', load[0], load[1], load[0])
    else:
      assert False

    for register in load:
      emitter.EmitVCvt('f32', 's32', register, register)

    for register in load:
      emitter.EmitVSub('f32', register, register, self.range_offset)

    for register in load:
      emitter.EmitVMul('f32', register, register, self.range_scale)

    for register in load:
      emitter.EmitVAdd('f32', register, register, self.range_min)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(32, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)
146
147
class QuantizeTransformation(object):
  """Converts float elements to uint8_t.

  Per element: (x - range_min) * range_scale + range_offset, converted to
  s32 and then narrowed with saturation to uint8_t.
  """

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested kernel shape is supported.

    Uses `==` rather than `is`: identity comparison with str/int literals only
    works by accident of CPython interning (and is a SyntaxWarning on 3.8+).
    """
    assert in_type == 'float'
    assert out_type == 'uint8_t'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate quantization offsets to vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('Quantize::Prepare')

    self.range_min = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_min', 'params.range_min'), 4)
    self.range_offset = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_offset', 'params.range_offset'), 4)
    self.range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_scale', 'params.range_scale'), 4)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Emit quantization inner loop code.

    Args:
      emitter: assembly emitter receiving the instructions.
      registers: register allocator; the quad registers used here are
        allocated and freed within this call.
      input_address: mapped asm operand for the input pointer.
      elements: number of elements to process (at most 16, see Check).
      output_address: mapped asm operand for the output pointer.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Quantize::Transform')
    # Four 32-bit float lanes per quad register. Floor division keeps this an
    # int under Python 3 (range() rejects floats).
    register_count = (elements + 3) // 4
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(32, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(64))

    for register in load:
      emitter.EmitVSub('f32', register, register, self.range_min)

    for register in load:
      emitter.EmitVMul('f32', register, register, self.range_scale)

    for register in load:
      emitter.EmitVAdd('f32', register, register, self.range_offset)

    for register in load:
      emitter.EmitVCvt('s32', 'f32', register, register)

    # Narrow s32 -> s16 (VQmovn), then s16 -> u8 (VQmovun), packing the
    # results into load[0].
    if len(load) == 1:
      emitter.EmitVQmovn('s32', load[0], load[0])
      emitter.EmitVQmovun('s16', load[0], load[0])
    elif len(load) == 2:
      emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
      emitter.EmitVQmovun('s16', load[0], load[0])
    elif len(load) == 3:
      emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
      emitter.EmitVQmovn('s32', load[2], load[2])
      emitter.EmitVQmovun2('s16', load[0], load[0], load[2])
    elif len(load) == 4:
      emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
      emitter.EmitVQmovn2('s32', load[2], load[2], load[3])
      emitter.EmitVQmovun2('s16', load[0], load[0], load[2])
    else:
      assert False

    emitter.EmitNewline()
    emitter.EmitVStoreAE(8, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)
215
216
class RequantizeTransformation(object):
  """Converts int32_t elements to uint8_t in a new quantization range.

  Per element:
    ((x - input_range_offset) * input_range_scale
     + (input_range_min - output_range_min)) * one_over_output_range_scale
  converted to s32 and then narrowed with saturation to uint8_t.
  """

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested kernel shape is supported.

    Uses `==` rather than `is`: identity comparison with str/int literals only
    works by accident of CPython interning (and is a SyntaxWarning on 3.8+).
    """
    assert in_type == 'int32_t'
    assert out_type == 'uint8_t'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate quantization parameters to vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('Requantize::Prepare')

    self.range_min_delta = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('input_range_min', 'params.input_range_min'), 4)
    self.output_range_min = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('output_range_min', 'params.output_range_min'),
        4)
    self.input_range_offset = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('input_range_offset',
                               'params.input_range_offset'), 4)
    self.input_range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('input_range_scale', 'params.input_range_scale'),
        4)
    self.one_over_output_range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('one_over_output_range_scale',
                               'params.one_over_output_range_scale'), 4)
    # Fold the two range minimums into one constant once, outside the loop:
    # range_min_delta = input_range_min - output_range_min.
    emitter.EmitVSub('f32', self.range_min_delta, self.range_min_delta,
                     self.output_range_min)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Emit requantization inner loop code.

    Args:
      emitter: assembly emitter receiving the instructions.
      registers: register allocator; the quad registers used here are
        allocated and freed within this call.
      input_address: mapped asm operand for the input pointer.
      elements: number of elements to process (at most 16, see Check).
      output_address: mapped asm operand for the output pointer.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Requantize::Transform')
    # Four 32-bit lanes per quad register. Floor division keeps this an int
    # under Python 3 (range() rejects floats).
    register_count = (elements + 3) // 4
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(32, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(64))

    for register in load:
      emitter.EmitVCvt('f32', 's32', register, register)

    for register in load:
      emitter.EmitVSub('f32', register, register, self.input_range_offset)

    for register in load:
      emitter.EmitVMul('f32', register, register, self.input_range_scale)

    for register in load:
      emitter.EmitVAdd('f32', register, register, self.range_min_delta)

    for register in load:
      emitter.EmitVMul('f32', register, register,
                       self.one_over_output_range_scale)

    for register in load:
      emitter.EmitVCvt('s32', 'f32', register, register)

    # Narrow s32 -> s16 (VQmovn), then s16 -> u8 (VQmovun), packing the
    # results into load[0].
    if len(load) == 1:
      emitter.EmitVQmovn('s32', load[0], load[0])
      emitter.EmitVQmovun('s16', load[0], load[0])
    elif len(load) == 2:
      emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
      emitter.EmitVQmovun('s16', load[0], load[0])
    elif len(load) == 3:
      emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
      emitter.EmitVQmovn('s32', load[2], load[2])
      emitter.EmitVQmovun2('s16', load[0], load[0], load[2])
    elif len(load) == 4:
      emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
      emitter.EmitVQmovn2('s32', load[2], load[2], load[3])
      emitter.EmitVQmovun2('s16', load[0], load[0], load[2])
    else:
      assert False

    emitter.EmitNewline()
    emitter.EmitVStoreAE(8, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)
303
304
class BaseTransform(common.Transform1DKernelGenerator):
  """1D kernel generator that wraps a pluggable transformation object.

  The transformation supplies Check(), Prepare() and Transform(). This class
  emits the surrounding asm block and the loop over params.count.
  """

  def __init__(self, cc_emitter, kernel_name, asm_emitter, transformation):
    common.Transform1DKernelGenerator.__init__(self, cc_emitter, kernel_name)
    self.asm_emitter = asm_emitter
    self.transformation = transformation

  def EmitTransform(self, in_type, out_type, kernel_size, leftovers):
    """Emit the asm body: a main loop plus an optional leftovers pass.

    When `leftovers` is non-zero, the count is first reduced by it. If that
    reaches zero the main loop is skipped (branch to label 2). Label 1 loops
    in steps of `kernel_size`. Label 2 then processes the leftover tail once.
    """
    transformation = self.transformation
    asm = self.asm_emitter
    cc = self.emitter

    transformation.Check(in_type, out_type, kernel_size, leftovers)

    registers = asm.CreateRegisters()

    # The asm decrements the count, so it works on a local copy.
    cc.EmitDeclare('int', 'params_count_copy', 'params.count')

    asm.PushIndent(cc.indent)
    asm.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    input_address = registers.MapOutputParameter('input')
    output_address = registers.MapOutputParameter('output')

    transformation.Prepare(asm, registers, kernel_size)

    if leftovers:
      asm.EmitNewline()
      asm.EmitComment('Reduce count by leftovers.')
      leftovers_constant = asm.ImmediateConstant(leftovers)
      asm.EmitSubs(count, count, leftovers_constant)
      asm.EmitBeqFront(2)

    asm.EmitNewline()
    asm.EmitNumericalLabel(1)
    kernel_constant = asm.ImmediateConstant(kernel_size)
    asm.EmitSubs(count, count, kernel_constant)

    transformation.Transform(asm, registers, input_address, kernel_size,
                             output_address)

    asm.EmitNewline()
    asm.EmitBneBack(1)

    if leftovers:
      asm.EmitNumericalLabel(2)
      asm.EmitNewline()
      asm.EmitComment('Handle leftovers.')
      transformation.Transform(asm, registers, input_address, leftovers,
                               output_address)

    asm.EmitAsmEnd(registers)
    asm.PopIndent(len(cc.indent))
357
358
class Requantize(BaseTransform):
  """Kernel generator for int32_t -> uint8_t requantization."""

  def __init__(self, cc_emitter, asm_emitter):
    transformation = RequantizeTransformation()
    BaseTransform.__init__(self, cc_emitter, 'Requantize', asm_emitter,
                           transformation)
365
366
class Quantize(BaseTransform):
  """Kernel generator for float -> uint8_t quantization."""

  def __init__(self, cc_emitter, asm_emitter):
    transformation = QuantizeTransformation()
    BaseTransform.__init__(self, cc_emitter, 'Quantize', asm_emitter,
                           transformation)
373
374
class Dequantize(BaseTransform):
  """Kernel generator for uint8_t -> float dequantization."""

  def __init__(self, cc_emitter, asm_emitter):
    transformation = DequantizeTransformation()
    BaseTransform.__init__(self, cc_emitter, 'Dequantize', asm_emitter,
                           transformation)
381
382
class MinMax(BaseTransform):
  """Kernel generator that clamps values into [params.min, params.max]."""

  def __init__(self, numerical_type, cc_emitter, asm_emitter):
    kernel_name = 'MinMax<%s>' % numerical_type
    BaseTransform.__init__(self, cc_emitter, kernel_name, asm_emitter,
                           MinMaxTransformation())
389
390
class BiasAdd(common.Transform1DKernelGenerator):
  """Generates kernels that add a uint8_t bias row to each uint8_t input row.

  Input and bias are each dequantized with their own (min, scale) pair and
  summed. The sum is then rescaled into the output range
  ((sum - output_range_min) * one_over_output_range_scale
  + output_range_offset) and stored as int32_t.
  """

  def __init__(self, bias_type, cc_emitter, asm_emitter):
    common.Transform1DKernelGenerator.__init__(self, cc_emitter,
                                               'BiasAdd<%s>' % bias_type)
    self.asm_emitter = asm_emitter

  def EmitTransform(self, in_type, out_type, kernel_size, leftovers):
    """Emit the asm body: a loop over params.rows, one _ProcessRow per row.

    Uses `==` rather than `is` for the shape checks: identity comparison with
    str/int literals only works by accident of CPython interning.
    """
    assert in_type == 'uint8_t'
    assert out_type == 'int32_t'
    assert kernel_size == 16
    assert leftovers < 16

    registers = self.asm_emitter.CreateRegisters()

    # The asm decrements the row counter, so it works on a local copy.
    self.emitter.EmitDeclare('int', 'params_rows_copy', 'params.rows')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    self._Prepare(self.asm_emitter, registers)

    # `rows` is modified by the subs below, so it must be a read-write
    # (output) asm operand; inline asm must not modify input-only operands.
    rows = registers.MapOutputParameter('rows', 'params_rows_copy')

    self.asm_emitter.EmitNumericalLabel(1)

    self._ProcessRow(self.asm_emitter, registers, kernel_size, leftovers)

    self.asm_emitter.EmitSubs(rows, rows, self.asm_emitter.ImmediateConstant(1))
    self.asm_emitter.EmitBneBack(1)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))

  def _Prepare(self, emitter, registers):
    """Load each range parameter from memory and broadcast it to a register.

    The parameters are mapped with MapMemoryParameter and duplicated through
    a scratch general register; quad registers are requested with
    min_register=8.
    """
    self.input_range_min = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('input_range_min',
                                     'params.input_range_min'), 8)
    self.input_range_scale = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('input_range_scale',
                                     'params.input_range_scale'), 8)
    self.bias_range_min = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('bias_range_min', 'params.bias_range_min'),
        8)
    self.bias_range_scale = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('bias_range_scale',
                                     'params.bias_range_scale'), 8)
    self.output_range_min = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('output_range_min',
                                     'params.output_range_min'), 8)
    self.one_over_output_range_scale = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('one_over_output_range_scale',
                                     'params.one_over_output_range_scale'), 8)
    self.output_range_offset = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('output_range_offset',
                                     'params.output_range_offset'), 8)

  def _ProcessRow(self, emitter, registers, kernel_size, leftovers):
    """Emit code for one row: a kernel_size loop plus an optional tail.

    count and bias are copied from params into general registers at the start
    of each row. Each row therefore restarts at the beginning of the bias
    vector, and the params operands themselves are never modified.
    """
    const_count = registers.MapParameter('count', 'params.count')
    const_bias = registers.MapParameter('bias', 'params.bias')

    count = registers.GeneralRegister()
    bias = registers.GeneralRegister()

    input_address = registers.MapOutputParameter('input')
    output_address = registers.MapOutputParameter('output')

    emitter.EmitMov(count, const_count)
    emitter.EmitMov(bias, const_bias)

    if leftovers:
      # Skip the main loop entirely when count == leftovers.
      emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
      emitter.EmitBeqFront(3)

    emitter.EmitNumericalLabel(2)
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(kernel_size))

    self._BiasAdd(emitter, registers, kernel_size, input_address, bias,
                  output_address)

    emitter.EmitBneBack(2)

    if leftovers:
      emitter.EmitNumericalLabel(3)
      self._BiasAdd(emitter, registers, leftovers, input_address, bias,
                    output_address)

  def _BiasAdd(self, emitter, registers, elements, input_address, bias,
               output_address):
    """Emit the bias-add computation for `elements` (at most 16) elements."""
    emitter.EmitNewline()
    emitter.EmitComment('BiasAdd::Transform')
    # After widening each quad register holds four 32-bit lanes. Floor
    # division keeps this an int under Python 3 (range() rejects floats).
    register_count = (elements + 3) // 4

    load_input = [
        registers.QuadRegister() for unused_i in range(register_count)
    ]
    load_bias = [registers.QuadRegister() for unused_i in range(register_count)]

    emitter.EmitVLoadAE(8, elements, load_input, input_address, None)
    emitter.EmitVLoadAE(8, elements, load_bias, bias, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(32))

    # Widen input and bias in place: u8 -> 16-bit, then s16 -> 32-bit lanes.
    # Input/bias instructions are interleaved; each register list follows
    # the same order-sensitive sequence.
    if len(load_input) == 1:
      emitter.EmitVMovl('u8', load_input[0], load_input[0])
      emitter.EmitVMovl('u8', load_bias[0], load_bias[0])
      emitter.EmitVMovl('s16', load_input[0], load_input[0])
      emitter.EmitVMovl('s16', load_bias[0], load_bias[0])
    elif len(load_input) == 2:
      emitter.EmitVMovl('u8', load_input[0], load_input[0])
      emitter.EmitVMovl('u8', load_bias[0], load_bias[0])
      emitter.EmitVMovl2('s16', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('s16', load_bias[0], load_bias[1], load_bias[0])
    elif len(load_input) == 3:
      emitter.EmitVMovl2('u8', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('u8', load_bias[0], load_bias[1], load_bias[0])
      # Consume register [1] before the next VMovl2 overwrites it.
      emitter.EmitVMovl('s16', load_input[2], load_input[1])
      emitter.EmitVMovl('s16', load_bias[2], load_bias[1])
      emitter.EmitVMovl2('s16', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('s16', load_bias[0], load_bias[1], load_bias[0])
    elif len(load_input) == 4:
      emitter.EmitVMovl2('u8', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('u8', load_bias[0], load_bias[1], load_bias[0])
      # Consume register [1] before the next VMovl2 overwrites it.
      emitter.EmitVMovl2('s16', load_input[2], load_input[3], load_input[1])
      emitter.EmitVMovl2('s16', load_bias[2], load_bias[3], load_bias[1])
      emitter.EmitVMovl2('s16', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('s16', load_bias[0], load_bias[1], load_bias[0])
    else:
      assert False

    for register in load_input + load_bias:
      emitter.EmitVCvt('f32', 's32', register, register)

    # Dequantize: x * scale + min, separately for input and bias.
    for register in load_input:
      emitter.EmitVMul('f32', register, register, self.input_range_scale)

    for register in load_bias:
      emitter.EmitVMul('f32', register, register, self.bias_range_scale)

    for register in load_input:
      emitter.EmitVAdd('f32', register, register, self.input_range_min)

    for register in load_bias:
      emitter.EmitVAdd('f32', register, register, self.bias_range_min)

    for (register_1, register_2) in zip(load_input, load_bias):
      emitter.EmitVAdd('f32', register_1, register_1, register_2)

    # Requantize the sum into the output range.
    for register in load_input:
      emitter.EmitVSub('f32', register, register, self.output_range_min)

    for register in load_input:
      emitter.EmitVMul('f32', register, register,
                       self.one_over_output_range_scale)

    for register in load_input:
      emitter.EmitVAdd('f32', register, register, self.output_range_offset)

    for register in load_input:
      emitter.EmitVCvt('s32', 'f32', register, register)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(32, elements, load_input, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load_input + load_bias)
564
565
def GenerateKernels(cc_emitter, asm_emitter, shapes):
  """Generate the quantization/dequantization/requantization kernels.

  Also emits the MinMax<uint8_t> and BiasAdd<uint8_t> kernels. Each generator
  is specialized for every entry of `shapes` before moving to the next
  generator. Each shape is indexed as shape[0], shape[1]. These are
  presumably (kernel_size, leftovers); confirm against
  SpecializeTransform1DKernel.
  """
  requantize = Requantize(cc_emitter, asm_emitter)
  quantize = Quantize(cc_emitter, asm_emitter)
  dequantize = Dequantize(cc_emitter, asm_emitter)
  minmax = MinMax('uint8_t', cc_emitter, asm_emitter)
  biasadd = BiasAdd('uint8_t', cc_emitter, asm_emitter)

  # (generator, in_type, out_type), in emission order.
  plan = (
      (requantize, 'int32_t', 'uint8_t'),
      (quantize, 'float', 'uint8_t'),
      (dequantize, 'uint8_t', 'float'),
      (minmax, 'uint8_t', 'uint8_t'),
      (biasadd, 'uint8_t', 'int32_t'),
  )

  for generator, in_type, out_type in plan:
    for shape in shapes:
      generator.SpecializeTransform1DKernel(in_type, out_type, shape[0],
                                            shape[1])
591