xref: /aosp_15_r20/external/gemmlowp/meta/generators/streams_common.py (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han# Copyright 2016 The Gemmlowp Authors. All rights reserved.
2*5f39d1b3SJooyung Han#
3*5f39d1b3SJooyung Han# Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han# you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han# You may obtain a copy of the License at
6*5f39d1b3SJooyung Han#
7*5f39d1b3SJooyung Han#    http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han#
9*5f39d1b3SJooyung Han# Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han# distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han# See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han# limitations under the License.
14*5f39d1b3SJooyung Han"""."""
15*5f39d1b3SJooyung Han
16*5f39d1b3SJooyung Hanimport common
17*5f39d1b3SJooyung Han
18*5f39d1b3SJooyung Han
19*5f39d1b3SJooyung Handef _AlignForLanes(lanes_count):
20*5f39d1b3SJooyung Han  if lanes_count is 8 or lanes_count is 4:
21*5f39d1b3SJooyung Han    return 256
22*5f39d1b3SJooyung Han  elif lanes_count is 6 or lanes_count is 2:
23*5f39d1b3SJooyung Han    return 128
24*5f39d1b3SJooyung Han  else:
25*5f39d1b3SJooyung Han    return 64
26*5f39d1b3SJooyung Han
27*5f39d1b3SJooyung Han
28*5f39d1b3SJooyung Handef _AlignForSums(lanes_count):
29*5f39d1b3SJooyung Han  if lanes_count is 8:
30*5f39d1b3SJooyung Han    return 256
31*5f39d1b3SJooyung Han  elif lanes_count in [2, 4, 6]:
32*5f39d1b3SJooyung Han    return 128
33*5f39d1b3SJooyung Han  else:
34*5f39d1b3SJooyung Han    return 64
35*5f39d1b3SJooyung Han
36*5f39d1b3SJooyung Han
37*5f39d1b3SJooyung Handef _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
38*5f39d1b3SJooyung Han  """."""
39*5f39d1b3SJooyung Han  inputs = []
40*5f39d1b3SJooyung Han  last_address_register = input_address
41*5f39d1b3SJooyung Han  for i in range(lanes_count):
42*5f39d1b3SJooyung Han    if not i:
43*5f39d1b3SJooyung Han      inputs.append(input_address)
44*5f39d1b3SJooyung Han    else:
45*5f39d1b3SJooyung Han      address_register = registers.GeneralRegister()
46*5f39d1b3SJooyung Han      inputs.append(address_register)
47*5f39d1b3SJooyung Han      emitter.EmitAdd(address_register, last_address_register, stride)
48*5f39d1b3SJooyung Han      last_address_register = address_register
49*5f39d1b3SJooyung Han  return inputs
50*5f39d1b3SJooyung Han
51*5f39d1b3SJooyung Han
52*5f39d1b3SJooyung Handef _GenerateClear(emitter, clear_type, block):
53*5f39d1b3SJooyung Han  for row in block:
54*5f39d1b3SJooyung Han    emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0))
55*5f39d1b3SJooyung Han
56*5f39d1b3SJooyung Han
57*5f39d1b3SJooyung Handef _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
58*5f39d1b3SJooyung Han                                aggregators, inputs, output):
59*5f39d1b3SJooyung Han  """Emit inner loop code for reading N lanes and interweaving them."""
60*5f39d1b3SJooyung Han  emitter.EmitNewline()
61*5f39d1b3SJooyung Han  emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count,
62*5f39d1b3SJooyung Han                                                        elements_count))
63*5f39d1b3SJooyung Han
64*5f39d1b3SJooyung Han  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]
65*5f39d1b3SJooyung Han
66*5f39d1b3SJooyung Han  if elements_count is not 8:
67*5f39d1b3SJooyung Han    _GenerateClear(emitter, 'i8', block)
68*5f39d1b3SJooyung Han
69*5f39d1b3SJooyung Han  for (row, input_address) in zip(block, inputs):
70*5f39d1b3SJooyung Han    emitter.EmitVLoadE(8, elements_count, row, input_address, None)
71*5f39d1b3SJooyung Han
72*5f39d1b3SJooyung Han  for (aggregator, row) in zip(aggregators, block):
73*5f39d1b3SJooyung Han    emitter.EmitVAddw('u8', aggregator, aggregator, row)
74*5f39d1b3SJooyung Han
75*5f39d1b3SJooyung Han  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
76*5f39d1b3SJooyung Han                       _AlignForLanes(lanes_count))
77*5f39d1b3SJooyung Han
78*5f39d1b3SJooyung Han  registers.FreeRegisters(block)
79*5f39d1b3SJooyung Han
80*5f39d1b3SJooyung Han
81*5f39d1b3SJooyung Handef _LoadMemoryParameter(emitter, registers, name, source):
82*5f39d1b3SJooyung Han  register = registers.GeneralRegister()
83*5f39d1b3SJooyung Han  emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
84*5f39d1b3SJooyung Han  return register
85*5f39d1b3SJooyung Han
86*5f39d1b3SJooyung Han
87*5f39d1b3SJooyung Handef _GenerateAggregatorReductionLowRegisters(emitter, registers,
88*5f39d1b3SJooyung Han                                             aggregators, output_address):
89*5f39d1b3SJooyung Han  emitter.EmitNewline()
90*5f39d1b3SJooyung Han  emitter.EmitComment('Aggregator Reduction.')
91*5f39d1b3SJooyung Han  _GenerateAggregatorReduction(
92*5f39d1b3SJooyung Han      emitter, registers, aggregators, output_address,
93*5f39d1b3SJooyung Han      _LoadMemoryParameter(emitter, registers, 'multiplicative_sum_offset',
94*5f39d1b3SJooyung Han                           'params.multiplicative_sum_offset'),
95*5f39d1b3SJooyung Han      _LoadMemoryParameter(emitter, registers, 'additive_sum_offset',
96*5f39d1b3SJooyung Han                           'params.additive_sum_offset'))
97*5f39d1b3SJooyung Han
98*5f39d1b3SJooyung Han
99*5f39d1b3SJooyung Handef _GenerateAggregatorReductionHighRegisters(emitter, registers,
100*5f39d1b3SJooyung Han                                              aggregators, output_address):
101*5f39d1b3SJooyung Han  emitter.EmitNewline()
102*5f39d1b3SJooyung Han  emitter.EmitComment('Aggregator Reduction.')
103*5f39d1b3SJooyung Han  _GenerateAggregatorReduction(
104*5f39d1b3SJooyung Han      emitter, registers, aggregators, output_address,
105*5f39d1b3SJooyung Han      registers.MapParameter('multiplicative_sum_offset',
106*5f39d1b3SJooyung Han                             'params.multiplicative_sum_offset'),
107*5f39d1b3SJooyung Han      registers.MapParameter('additive_sum_offset',
108*5f39d1b3SJooyung Han                             'params.additive_sum_offset'))
109*5f39d1b3SJooyung Han
110*5f39d1b3SJooyung Han
111*5f39d1b3SJooyung Handef _GenerateAggregatorReduction(emitter, registers, aggregators,
112*5f39d1b3SJooyung Han                                 output_address, multiplicative_sum_offset,
113*5f39d1b3SJooyung Han                                 additive_sum_offset):
114*5f39d1b3SJooyung Han  """Reduce 4 lane sum aggregators to 1 value and store the sums."""
115*5f39d1b3SJooyung Han  multiplier = registers.DoubleRegister()
116*5f39d1b3SJooyung Han  emitter.EmitVMov('32',
117*5f39d1b3SJooyung Han                   emitter.Lane(32, multiplier, 0), multiplicative_sum_offset)
118*5f39d1b3SJooyung Han
119*5f39d1b3SJooyung Han  offset = registers.QuadRegister()
120*5f39d1b3SJooyung Han  emitter.EmitVDup('32', offset, additive_sum_offset)
121*5f39d1b3SJooyung Han
122*5f39d1b3SJooyung Han  for aggregator in aggregators:
123*5f39d1b3SJooyung Han    emitter.EmitVPaddl('u16', aggregator, aggregator)
124*5f39d1b3SJooyung Han
125*5f39d1b3SJooyung Han  reduced_count = (len(aggregators) + 3) / 4
126*5f39d1b3SJooyung Han  reduced = aggregators[:reduced_count]
127*5f39d1b3SJooyung Han
128*5f39d1b3SJooyung Han  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
129*5f39d1b3SJooyung Han
130*5f39d1b3SJooyung Han  for temp in reduced:
131*5f39d1b3SJooyung Han    emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0))
132*5f39d1b3SJooyung Han
133*5f39d1b3SJooyung Han  for temp in reduced:
134*5f39d1b3SJooyung Han    emitter.EmitVAdd('i32', temp, temp, offset)
135*5f39d1b3SJooyung Han
136*5f39d1b3SJooyung Han  emitter.EmitVStoreA(1, 32, reduced,
137*5f39d1b3SJooyung Han                      emitter.Dereference(output_address,
138*5f39d1b3SJooyung Han                                          _AlignForSums(len(aggregators))))
139*5f39d1b3SJooyung Han
140*5f39d1b3SJooyung Han
141*5f39d1b3SJooyung Hanclass RowMajorWithSumUInt8x8(common.StreamGenerator):
142*5f39d1b3SJooyung Han  """."""
143*5f39d1b3SJooyung Han
144*5f39d1b3SJooyung Han  def __init__(self, emitter, asm_emitter):
145*5f39d1b3SJooyung Han    common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
146*5f39d1b3SJooyung Han    self.asm_emitter = asm_emitter
147*5f39d1b3SJooyung Han
148*5f39d1b3SJooyung Han  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
149*5f39d1b3SJooyung Han    assert pack_size is 8
150*5f39d1b3SJooyung Han    assert in_type is 'uint8_t'
151*5f39d1b3SJooyung Han
152*5f39d1b3SJooyung Han    registers = self.asm_emitter.CreateRegisters()
153*5f39d1b3SJooyung Han
154*5f39d1b3SJooyung Han    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
155*5f39d1b3SJooyung Han
156*5f39d1b3SJooyung Han    self.asm_emitter.PushIndent(self.emitter.indent)
157*5f39d1b3SJooyung Han    self.asm_emitter.EmitAsmBegin()
158*5f39d1b3SJooyung Han
159*5f39d1b3SJooyung Han    count = registers.MapOutputParameter('count', 'params_count_copy')
160*5f39d1b3SJooyung Han    output = registers.MapOutputParameter('out')
161*5f39d1b3SJooyung Han    inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count,
162*5f39d1b3SJooyung Han                             registers.MapOutputParameter('in'),
163*5f39d1b3SJooyung Han                             registers.MapParameter('stride', 'params.stride'))
164*5f39d1b3SJooyung Han    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
165*5f39d1b3SJooyung Han
166*5f39d1b3SJooyung Han    _GenerateClear(self.asm_emitter, 'i16', aggregators)
167*5f39d1b3SJooyung Han
168*5f39d1b3SJooyung Han    if leftovers:
169*5f39d1b3SJooyung Han      self.asm_emitter.EmitNewline()
170*5f39d1b3SJooyung Han      self.asm_emitter.EmitComment('Reduce count by leftovers.')
171*5f39d1b3SJooyung Han      self.asm_emitter.EmitSubs(count, count,
172*5f39d1b3SJooyung Han                                self.asm_emitter.ImmediateConstant(leftovers))
173*5f39d1b3SJooyung Han      self.asm_emitter.EmitBeqFront(2)
174*5f39d1b3SJooyung Han
175*5f39d1b3SJooyung Han    self.asm_emitter.EmitNewline()
176*5f39d1b3SJooyung Han    self.asm_emitter.EmitNumericalLabel(1)
177*5f39d1b3SJooyung Han    self.asm_emitter.EmitSubs(count, count,
178*5f39d1b3SJooyung Han                              self.asm_emitter.ImmediateConstant(8))
179*5f39d1b3SJooyung Han
180*5f39d1b3SJooyung Han    _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
181*5f39d1b3SJooyung Han                                aggregators, inputs, output)
182*5f39d1b3SJooyung Han
183*5f39d1b3SJooyung Han    self.asm_emitter.EmitNewline()
184*5f39d1b3SJooyung Han    self.asm_emitter.EmitBneBack(1)
185*5f39d1b3SJooyung Han
186*5f39d1b3SJooyung Han    if leftovers:
187*5f39d1b3SJooyung Han      self.asm_emitter.EmitNewline()
188*5f39d1b3SJooyung Han      self.asm_emitter.EmitNumericalLabel(2)
189*5f39d1b3SJooyung Han      _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
190*5f39d1b3SJooyung Han                                  leftovers, aggregators, inputs, output)
191*5f39d1b3SJooyung Han
192*5f39d1b3SJooyung Han    registers.FreeRegisters(inputs)
193*5f39d1b3SJooyung Han
194*5f39d1b3SJooyung Han    if len(inputs) <= 6:
195*5f39d1b3SJooyung Han      _GenerateAggregatorReductionHighRegisters(
196*5f39d1b3SJooyung Han          self.asm_emitter, registers, aggregators, output)
197*5f39d1b3SJooyung Han    else:
198*5f39d1b3SJooyung Han      _GenerateAggregatorReductionLowRegisters(
199*5f39d1b3SJooyung Han          self.asm_emitter, registers, aggregators, output)
200*5f39d1b3SJooyung Han
201*5f39d1b3SJooyung Han    self.asm_emitter.EmitAsmEnd(registers)
202*5f39d1b3SJooyung Han    self.asm_emitter.PopIndent(len(self.emitter.indent))
203*5f39d1b3SJooyung Han
204*5f39d1b3SJooyung Han
205*5f39d1b3SJooyung Handef _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
206*5f39d1b3SJooyung Han                                   elements_count, aggregators, input_address,
207*5f39d1b3SJooyung Han                                   stride, output):
208*5f39d1b3SJooyung Han  """Emit inner loop code for reading N col lanes and interweaving them."""
209*5f39d1b3SJooyung Han  emitter.EmitNewline()
210*5f39d1b3SJooyung Han  emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
211*5f39d1b3SJooyung Han                      (lanes_count, elements_count))
212*5f39d1b3SJooyung Han
213*5f39d1b3SJooyung Han  block = [registers.DoubleRegister() for unused_i in range(lanes_count)]
214*5f39d1b3SJooyung Han
215*5f39d1b3SJooyung Han  if elements_count is not 8:
216*5f39d1b3SJooyung Han    _GenerateClear(emitter, 'i8', block)
217*5f39d1b3SJooyung Han
218*5f39d1b3SJooyung Han  block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count,
219*5f39d1b3SJooyung Han                                   block, input_address, stride)
220*5f39d1b3SJooyung Han
221*5f39d1b3SJooyung Han  for (aggregator, row) in zip(aggregators, block):
222*5f39d1b3SJooyung Han    emitter.EmitVAddw('u8', aggregator, aggregator, row)
223*5f39d1b3SJooyung Han
224*5f39d1b3SJooyung Han  emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
225*5f39d1b3SJooyung Han                       _AlignForLanes(lanes_count))
226*5f39d1b3SJooyung Han
227*5f39d1b3SJooyung Han  registers.FreeRegisters(block)
228*5f39d1b3SJooyung Han
229*5f39d1b3SJooyung Han
230*5f39d1b3SJooyung Hanclass ColumnMajorWithSumUInt8x8(common.StreamGenerator):
231*5f39d1b3SJooyung Han  """."""
232*5f39d1b3SJooyung Han
233*5f39d1b3SJooyung Han  def __init__(self, emitter, asm_emitter):
234*5f39d1b3SJooyung Han    common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
235*5f39d1b3SJooyung Han    self.asm_emitter = asm_emitter
236*5f39d1b3SJooyung Han
237*5f39d1b3SJooyung Han  def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
238*5f39d1b3SJooyung Han    assert pack_size is 8
239*5f39d1b3SJooyung Han    assert in_type is 'uint8_t'
240*5f39d1b3SJooyung Han
241*5f39d1b3SJooyung Han    registers = self.asm_emitter.CreateRegisters()
242*5f39d1b3SJooyung Han
243*5f39d1b3SJooyung Han    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
244*5f39d1b3SJooyung Han    self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')
245*5f39d1b3SJooyung Han
246*5f39d1b3SJooyung Han    self.asm_emitter.PushIndent(self.emitter.indent)
247*5f39d1b3SJooyung Han    self.asm_emitter.EmitAsmBegin()
248*5f39d1b3SJooyung Han
249*5f39d1b3SJooyung Han    count = registers.MapOutputParameter('count', 'params_count_copy')
250*5f39d1b3SJooyung Han    input_address = registers.MapOutputParameter('in')
251*5f39d1b3SJooyung Han    output_address = registers.MapOutputParameter('out')
252*5f39d1b3SJooyung Han    aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
253*5f39d1b3SJooyung Han    stride = registers.MapOutputParameter('stride', 'params_stride_copy')
254*5f39d1b3SJooyung Han
255*5f39d1b3SJooyung Han    self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)
256*5f39d1b3SJooyung Han
257*5f39d1b3SJooyung Han    _GenerateClear(self.asm_emitter, 'i16', aggregators)
258*5f39d1b3SJooyung Han
259*5f39d1b3SJooyung Han    if leftovers:
260*5f39d1b3SJooyung Han      self.asm_emitter.EmitNewline()
261*5f39d1b3SJooyung Han      self.asm_emitter.EmitComment('Reduce count by leftovers.')
262*5f39d1b3SJooyung Han      self.asm_emitter.EmitSubs(count, count,
263*5f39d1b3SJooyung Han                                self.asm_emitter.ImmediateConstant(leftovers))
264*5f39d1b3SJooyung Han      self.asm_emitter.EmitBeqFront(2)
265*5f39d1b3SJooyung Han
266*5f39d1b3SJooyung Han    self.asm_emitter.EmitNewline()
267*5f39d1b3SJooyung Han    self.asm_emitter.EmitNumericalLabel(1)
268*5f39d1b3SJooyung Han    self.asm_emitter.EmitSubs(count, count,
269*5f39d1b3SJooyung Han                              self.asm_emitter.ImmediateConstant(8))
270*5f39d1b3SJooyung Han
271*5f39d1b3SJooyung Han    _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
272*5f39d1b3SJooyung Han                                   aggregators, input_address, stride,
273*5f39d1b3SJooyung Han                                   output_address)
274*5f39d1b3SJooyung Han
275*5f39d1b3SJooyung Han    self.asm_emitter.EmitNewline()
276*5f39d1b3SJooyung Han    self.asm_emitter.EmitBneBack(1)
277*5f39d1b3SJooyung Han
278*5f39d1b3SJooyung Han    if leftovers:
279*5f39d1b3SJooyung Han      self.asm_emitter.EmitNewline()
280*5f39d1b3SJooyung Han      self.asm_emitter.EmitNumericalLabel(2)
281*5f39d1b3SJooyung Han      _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count,
282*5f39d1b3SJooyung Han                                     leftovers, aggregators, input_address,
283*5f39d1b3SJooyung Han                                     stride, output_address)
284*5f39d1b3SJooyung Han
285*5f39d1b3SJooyung Han
286*5f39d1b3SJooyung Han    _GenerateAggregatorReductionHighRegisters(
287*5f39d1b3SJooyung Han        self.asm_emitter, registers, aggregators, output_address)
288*5f39d1b3SJooyung Han
289*5f39d1b3SJooyung Han    self.asm_emitter.EmitAsmEnd(registers)
290*5f39d1b3SJooyung Han    self.asm_emitter.PopIndent(len(self.emitter.indent))
291*5f39d1b3SJooyung Han
292*5f39d1b3SJooyung Han
293*5f39d1b3SJooyung Handef GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
294*5f39d1b3SJooyung Han  row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter)
295*5f39d1b3SJooyung Han  column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter)
296*5f39d1b3SJooyung Han
297*5f39d1b3SJooyung Han  for lanes_count in range(1, 1 + lanes_count):
298*5f39d1b3SJooyung Han    for leftovers in range(8):
299*5f39d1b3SJooyung Han      row_major_with_sum.SpecializeStream('uint8_t', lanes_count, 8, leftovers)
300*5f39d1b3SJooyung Han
301*5f39d1b3SJooyung Han  for lanes_count in range(1, 1 + lanes_count):
302*5f39d1b3SJooyung Han    for leftovers in range(8):
303*5f39d1b3SJooyung Han      column_major_with_sum.SpecializeStream('uint8_t', lanes_count, 8,
304*5f39d1b3SJooyung Han                                             leftovers)
305