1*5f39d1b3SJooyung Han# Copyright 2016 The Gemmlowp Authors. All rights reserved. 2*5f39d1b3SJooyung Han# 3*5f39d1b3SJooyung Han# Licensed under the Apache License, Version 2.0 (the "License"); 4*5f39d1b3SJooyung Han# you may not use this file except in compliance with the License. 5*5f39d1b3SJooyung Han# You may obtain a copy of the License at 6*5f39d1b3SJooyung Han# 7*5f39d1b3SJooyung Han# http://www.apache.org/licenses/LICENSE-2.0 8*5f39d1b3SJooyung Han# 9*5f39d1b3SJooyung Han# Unless required by applicable law or agreed to in writing, software 10*5f39d1b3SJooyung Han# distributed under the License is distributed on an "AS IS" BASIS, 11*5f39d1b3SJooyung Han# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*5f39d1b3SJooyung Han# See the License for the specific language governing permissions and 13*5f39d1b3SJooyung Han# limitations under the License. 14*5f39d1b3SJooyung Han""".""" 15*5f39d1b3SJooyung Han 16*5f39d1b3SJooyung Hanimport common 17*5f39d1b3SJooyung Han 18*5f39d1b3SJooyung Han 19*5f39d1b3SJooyung Handef _AlignForLanes(lanes_count): 20*5f39d1b3SJooyung Han if lanes_count is 8 or lanes_count is 4: 21*5f39d1b3SJooyung Han return 256 22*5f39d1b3SJooyung Han elif lanes_count is 6 or lanes_count is 2: 23*5f39d1b3SJooyung Han return 128 24*5f39d1b3SJooyung Han else: 25*5f39d1b3SJooyung Han return 64 26*5f39d1b3SJooyung Han 27*5f39d1b3SJooyung Han 28*5f39d1b3SJooyung Handef _AlignForSums(lanes_count): 29*5f39d1b3SJooyung Han if lanes_count is 8: 30*5f39d1b3SJooyung Han return 256 31*5f39d1b3SJooyung Han elif lanes_count in [2, 4, 6]: 32*5f39d1b3SJooyung Han return 128 33*5f39d1b3SJooyung Han else: 34*5f39d1b3SJooyung Han return 64 35*5f39d1b3SJooyung Han 36*5f39d1b3SJooyung Han 37*5f39d1b3SJooyung Handef _GenerateInputs(emitter, registers, lanes_count, input_address, stride): 38*5f39d1b3SJooyung Han """.""" 39*5f39d1b3SJooyung Han inputs = [] 40*5f39d1b3SJooyung Han last_address_register = input_address 41*5f39d1b3SJooyung Han for i in range(lanes_count): 42*5f39d1b3SJooyung Han if not i: 43*5f39d1b3SJooyung Han inputs.append(input_address) 44*5f39d1b3SJooyung Han else: 45*5f39d1b3SJooyung Han address_register = registers.GeneralRegister() 46*5f39d1b3SJooyung Han inputs.append(address_register) 47*5f39d1b3SJooyung Han emitter.EmitAdd(address_register, last_address_register, stride) 48*5f39d1b3SJooyung Han last_address_register = address_register 49*5f39d1b3SJooyung Han return inputs 50*5f39d1b3SJooyung Han 51*5f39d1b3SJooyung Han 52*5f39d1b3SJooyung Handef _GenerateClear(emitter, clear_type, block): 53*5f39d1b3SJooyung Han for row in block: 54*5f39d1b3SJooyung Han emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0)) 55*5f39d1b3SJooyung Han 56*5f39d1b3SJooyung Han 57*5f39d1b3SJooyung Handef _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count, 58*5f39d1b3SJooyung Han aggregators, inputs, output): 59*5f39d1b3SJooyung Han """Emit inner loop code for reading N lanes and interweaving them.""" 60*5f39d1b3SJooyung Han emitter.EmitNewline() 61*5f39d1b3SJooyung Han emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count, 62*5f39d1b3SJooyung Han elements_count)) 63*5f39d1b3SJooyung Han 64*5f39d1b3SJooyung Han block = [registers.DoubleRegister() for unused_i in range(lanes_count)] 65*5f39d1b3SJooyung Han 66*5f39d1b3SJooyung Han if elements_count is not 8: 67*5f39d1b3SJooyung Han _GenerateClear(emitter, 'i8', block) 68*5f39d1b3SJooyung Han 69*5f39d1b3SJooyung Han for (row, input_address) in zip(block, inputs): 70*5f39d1b3SJooyung Han emitter.EmitVLoadE(8, elements_count, row, input_address, None) 71*5f39d1b3SJooyung Han 72*5f39d1b3SJooyung Han for (aggregator, row) in zip(aggregators, block): 73*5f39d1b3SJooyung Han emitter.EmitVAddw('u8', aggregator, aggregator, row) 74*5f39d1b3SJooyung Han 75*5f39d1b3SJooyung Han emitter.EmitVStoreAE(8, 8 * lanes_count, block, output, 76*5f39d1b3SJooyung Han _AlignForLanes(lanes_count)) 77*5f39d1b3SJooyung Han 78*5f39d1b3SJooyung Han registers.FreeRegisters(block) 79*5f39d1b3SJooyung Han 80*5f39d1b3SJooyung Han 81*5f39d1b3SJooyung Handef _LoadMemoryParameter(emitter, registers, name, source): 82*5f39d1b3SJooyung Han register = registers.GeneralRegister() 83*5f39d1b3SJooyung Han emitter.EmitLdr(register, registers.MapMemoryParameter(name, source)) 84*5f39d1b3SJooyung Han return register 85*5f39d1b3SJooyung Han 86*5f39d1b3SJooyung Han 87*5f39d1b3SJooyung Handef _GenerateAggregatorReductionLowRegisters(emitter, registers, 88*5f39d1b3SJooyung Han aggregators, output_address): 89*5f39d1b3SJooyung Han emitter.EmitNewline() 90*5f39d1b3SJooyung Han emitter.EmitComment('Aggregator Reduction.') 91*5f39d1b3SJooyung Han _GenerateAggregatorReduction( 92*5f39d1b3SJooyung Han emitter, registers, aggregators, output_address, 93*5f39d1b3SJooyung Han _LoadMemoryParameter(emitter, registers, 'multiplicative_sum_offset', 94*5f39d1b3SJooyung Han 'params.multiplicative_sum_offset'), 95*5f39d1b3SJooyung Han _LoadMemoryParameter(emitter, registers, 'additive_sum_offset', 96*5f39d1b3SJooyung Han 'params.additive_sum_offset')) 97*5f39d1b3SJooyung Han 98*5f39d1b3SJooyung Han 99*5f39d1b3SJooyung Handef _GenerateAggregatorReductionHighRegisters(emitter, registers, 100*5f39d1b3SJooyung Han aggregators, output_address): 101*5f39d1b3SJooyung Han emitter.EmitNewline() 102*5f39d1b3SJooyung Han emitter.EmitComment('Aggregator Reduction.') 103*5f39d1b3SJooyung Han _GenerateAggregatorReduction( 104*5f39d1b3SJooyung Han emitter, registers, aggregators, output_address, 105*5f39d1b3SJooyung Han registers.MapParameter('multiplicative_sum_offset', 106*5f39d1b3SJooyung Han 'params.multiplicative_sum_offset'), 107*5f39d1b3SJooyung Han registers.MapParameter('additive_sum_offset', 108*5f39d1b3SJooyung Han 'params.additive_sum_offset')) 109*5f39d1b3SJooyung Han 110*5f39d1b3SJooyung Han 111*5f39d1b3SJooyung Handef _GenerateAggregatorReduction(emitter, registers, aggregators, 112*5f39d1b3SJooyung Han output_address, multiplicative_sum_offset, 113*5f39d1b3SJooyung Han additive_sum_offset): 114*5f39d1b3SJooyung Han """Reduce 4 lane sum aggregators to 1 value and store the sums.""" 115*5f39d1b3SJooyung Han multiplier = registers.DoubleRegister() 116*5f39d1b3SJooyung Han emitter.EmitVMov('32', 117*5f39d1b3SJooyung Han emitter.Lane(32, multiplier, 0), multiplicative_sum_offset) 118*5f39d1b3SJooyung Han 119*5f39d1b3SJooyung Han offset = registers.QuadRegister() 120*5f39d1b3SJooyung Han emitter.EmitVDup('32', offset, additive_sum_offset) 121*5f39d1b3SJooyung Han 122*5f39d1b3SJooyung Han for aggregator in aggregators: 123*5f39d1b3SJooyung Han emitter.EmitVPaddl('u16', aggregator, aggregator) 124*5f39d1b3SJooyung Han 125*5f39d1b3SJooyung Han reduced_count = (len(aggregators) + 3) / 4 126*5f39d1b3SJooyung Han reduced = aggregators[:reduced_count] 127*5f39d1b3SJooyung Han 128*5f39d1b3SJooyung Han emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators) 129*5f39d1b3SJooyung Han 130*5f39d1b3SJooyung Han for temp in reduced: 131*5f39d1b3SJooyung Han emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0)) 132*5f39d1b3SJooyung Han 133*5f39d1b3SJooyung Han for temp in reduced: 134*5f39d1b3SJooyung Han emitter.EmitVAdd('i32', temp, temp, offset) 135*5f39d1b3SJooyung Han 136*5f39d1b3SJooyung Han emitter.EmitVStoreA(1, 32, reduced, 137*5f39d1b3SJooyung Han emitter.Dereference(output_address, 138*5f39d1b3SJooyung Han _AlignForSums(len(aggregators)))) 139*5f39d1b3SJooyung Han 140*5f39d1b3SJooyung Han 141*5f39d1b3SJooyung Hanclass RowMajorWithSumUInt8x8(common.StreamGenerator): 142*5f39d1b3SJooyung Han """.""" 143*5f39d1b3SJooyung Han 144*5f39d1b3SJooyung Han def __init__(self, emitter, asm_emitter): 145*5f39d1b3SJooyung Han common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum') 146*5f39d1b3SJooyung Han self.asm_emitter = asm_emitter 147*5f39d1b3SJooyung Han 148*5f39d1b3SJooyung Han def EmitPack(self, in_type, lanes_count, pack_size, leftovers): 149*5f39d1b3SJooyung Han assert pack_size is 8 150*5f39d1b3SJooyung Han assert in_type is 'uint8_t' 151*5f39d1b3SJooyung Han 152*5f39d1b3SJooyung Han registers = self.asm_emitter.CreateRegisters() 153*5f39d1b3SJooyung Han 154*5f39d1b3SJooyung Han self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count') 155*5f39d1b3SJooyung Han 156*5f39d1b3SJooyung Han self.asm_emitter.PushIndent(self.emitter.indent) 157*5f39d1b3SJooyung Han self.asm_emitter.EmitAsmBegin() 158*5f39d1b3SJooyung Han 159*5f39d1b3SJooyung Han count = registers.MapOutputParameter('count', 'params_count_copy') 160*5f39d1b3SJooyung Han output = registers.MapOutputParameter('out') 161*5f39d1b3SJooyung Han inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count, 162*5f39d1b3SJooyung Han registers.MapOutputParameter('in'), 163*5f39d1b3SJooyung Han registers.MapParameter('stride', 'params.stride')) 164*5f39d1b3SJooyung Han aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)] 165*5f39d1b3SJooyung Han 166*5f39d1b3SJooyung Han _GenerateClear(self.asm_emitter, 'i16', aggregators) 167*5f39d1b3SJooyung Han 168*5f39d1b3SJooyung Han if leftovers: 169*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 170*5f39d1b3SJooyung Han self.asm_emitter.EmitComment('Reduce count by leftovers.') 171*5f39d1b3SJooyung Han self.asm_emitter.EmitSubs(count, count, 172*5f39d1b3SJooyung Han self.asm_emitter.ImmediateConstant(leftovers)) 173*5f39d1b3SJooyung Han self.asm_emitter.EmitBeqFront(2) 174*5f39d1b3SJooyung Han 175*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 176*5f39d1b3SJooyung Han self.asm_emitter.EmitNumericalLabel(1) 177*5f39d1b3SJooyung Han self.asm_emitter.EmitSubs(count, count, 178*5f39d1b3SJooyung Han self.asm_emitter.ImmediateConstant(8)) 179*5f39d1b3SJooyung Han 180*5f39d1b3SJooyung Han _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8, 181*5f39d1b3SJooyung Han aggregators, inputs, output) 182*5f39d1b3SJooyung Han 183*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 184*5f39d1b3SJooyung Han self.asm_emitter.EmitBneBack(1) 185*5f39d1b3SJooyung Han 186*5f39d1b3SJooyung Han if leftovers: 187*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 188*5f39d1b3SJooyung Han self.asm_emitter.EmitNumericalLabel(2) 189*5f39d1b3SJooyung Han _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 190*5f39d1b3SJooyung Han leftovers, aggregators, inputs, output) 191*5f39d1b3SJooyung Han 192*5f39d1b3SJooyung Han registers.FreeRegisters(inputs) 193*5f39d1b3SJooyung Han 194*5f39d1b3SJooyung Han if len(inputs) <= 6: 195*5f39d1b3SJooyung Han _GenerateAggregatorReductionHighRegisters( 196*5f39d1b3SJooyung Han self.asm_emitter, registers, aggregators, output) 197*5f39d1b3SJooyung Han else: 198*5f39d1b3SJooyung Han _GenerateAggregatorReductionLowRegisters( 199*5f39d1b3SJooyung Han self.asm_emitter, registers, aggregators, output) 200*5f39d1b3SJooyung Han 201*5f39d1b3SJooyung Han self.asm_emitter.EmitAsmEnd(registers) 202*5f39d1b3SJooyung Han self.asm_emitter.PopIndent(len(self.emitter.indent)) 203*5f39d1b3SJooyung Han 204*5f39d1b3SJooyung Han 205*5f39d1b3SJooyung Handef _GenerateColLoadAggregateStore(emitter, registers, lanes_count, 206*5f39d1b3SJooyung Han elements_count, aggregators, input_address, 207*5f39d1b3SJooyung Han stride, output): 208*5f39d1b3SJooyung Han """Emit inner loop code for reading N col lanes and interweaving them.""" 209*5f39d1b3SJooyung Han emitter.EmitNewline() 210*5f39d1b3SJooyung Han emitter.EmitComment('Load Aggregate Store - column major %dx%d' % 211*5f39d1b3SJooyung Han (lanes_count, elements_count)) 212*5f39d1b3SJooyung Han 213*5f39d1b3SJooyung Han block = [registers.DoubleRegister() for unused_i in range(lanes_count)] 214*5f39d1b3SJooyung Han 215*5f39d1b3SJooyung Han if elements_count is not 8: 216*5f39d1b3SJooyung Han _GenerateClear(emitter, 'i8', block) 217*5f39d1b3SJooyung Han 218*5f39d1b3SJooyung Han block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count, 219*5f39d1b3SJooyung Han block, input_address, stride) 220*5f39d1b3SJooyung Han 221*5f39d1b3SJooyung Han for (aggregator, row) in zip(aggregators, block): 222*5f39d1b3SJooyung Han emitter.EmitVAddw('u8', aggregator, aggregator, row) 223*5f39d1b3SJooyung Han 224*5f39d1b3SJooyung Han emitter.EmitVStoreAE(8, 8 * lanes_count, block, output, 225*5f39d1b3SJooyung Han _AlignForLanes(lanes_count)) 226*5f39d1b3SJooyung Han 227*5f39d1b3SJooyung Han registers.FreeRegisters(block) 228*5f39d1b3SJooyung Han 229*5f39d1b3SJooyung Han 230*5f39d1b3SJooyung Hanclass ColumnMajorWithSumUInt8x8(common.StreamGenerator): 231*5f39d1b3SJooyung Han """.""" 232*5f39d1b3SJooyung Han 233*5f39d1b3SJooyung Han def __init__(self, emitter, asm_emitter): 234*5f39d1b3SJooyung Han common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum') 235*5f39d1b3SJooyung Han self.asm_emitter = asm_emitter 236*5f39d1b3SJooyung Han 237*5f39d1b3SJooyung Han def EmitPack(self, in_type, lanes_count, pack_size, leftovers): 238*5f39d1b3SJooyung Han assert pack_size is 8 239*5f39d1b3SJooyung Han assert in_type is 'uint8_t' 240*5f39d1b3SJooyung Han 241*5f39d1b3SJooyung Han registers = self.asm_emitter.CreateRegisters() 242*5f39d1b3SJooyung Han 243*5f39d1b3SJooyung Han self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count') 244*5f39d1b3SJooyung Han self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride') 245*5f39d1b3SJooyung Han 246*5f39d1b3SJooyung Han self.asm_emitter.PushIndent(self.emitter.indent) 247*5f39d1b3SJooyung Han self.asm_emitter.EmitAsmBegin() 248*5f39d1b3SJooyung Han 249*5f39d1b3SJooyung Han count = registers.MapOutputParameter('count', 'params_count_copy') 250*5f39d1b3SJooyung Han input_address = registers.MapOutputParameter('in') 251*5f39d1b3SJooyung Han output_address = registers.MapOutputParameter('out') 252*5f39d1b3SJooyung Han aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)] 253*5f39d1b3SJooyung Han stride = registers.MapOutputParameter('stride', 'params_stride_copy') 254*5f39d1b3SJooyung Han 255*5f39d1b3SJooyung Han self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride) 256*5f39d1b3SJooyung Han 257*5f39d1b3SJooyung Han _GenerateClear(self.asm_emitter, 'i16', aggregators) 258*5f39d1b3SJooyung Han 259*5f39d1b3SJooyung Han if leftovers: 260*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 261*5f39d1b3SJooyung Han self.asm_emitter.EmitComment('Reduce count by leftovers.') 262*5f39d1b3SJooyung Han self.asm_emitter.EmitSubs(count, count, 263*5f39d1b3SJooyung Han self.asm_emitter.ImmediateConstant(leftovers)) 264*5f39d1b3SJooyung Han self.asm_emitter.EmitBeqFront(2) 265*5f39d1b3SJooyung Han 266*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 267*5f39d1b3SJooyung Han self.asm_emitter.EmitNumericalLabel(1) 268*5f39d1b3SJooyung Han self.asm_emitter.EmitSubs(count, count, 269*5f39d1b3SJooyung Han self.asm_emitter.ImmediateConstant(8)) 270*5f39d1b3SJooyung Han 271*5f39d1b3SJooyung Han _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8, 272*5f39d1b3SJooyung Han aggregators, input_address, stride, 273*5f39d1b3SJooyung Han output_address) 274*5f39d1b3SJooyung Han 275*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 276*5f39d1b3SJooyung Han self.asm_emitter.EmitBneBack(1) 277*5f39d1b3SJooyung Han 278*5f39d1b3SJooyung Han if leftovers: 279*5f39d1b3SJooyung Han self.asm_emitter.EmitNewline() 280*5f39d1b3SJooyung Han self.asm_emitter.EmitNumericalLabel(2) 281*5f39d1b3SJooyung Han _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 282*5f39d1b3SJooyung Han leftovers, aggregators, input_address, 283*5f39d1b3SJooyung Han stride, output_address) 284*5f39d1b3SJooyung Han 285*5f39d1b3SJooyung Han 286*5f39d1b3SJooyung Han _GenerateAggregatorReductionHighRegisters( 287*5f39d1b3SJooyung Han self.asm_emitter, registers, aggregators, output_address) 288*5f39d1b3SJooyung Han 289*5f39d1b3SJooyung Han self.asm_emitter.EmitAsmEnd(registers) 290*5f39d1b3SJooyung Han self.asm_emitter.PopIndent(len(self.emitter.indent)) 291*5f39d1b3SJooyung Han 292*5f39d1b3SJooyung Han 293*5f39d1b3SJooyung Handef GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count): 294*5f39d1b3SJooyung Han row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter) 295*5f39d1b3SJooyung Han column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter) 296*5f39d1b3SJooyung Han 297*5f39d1b3SJooyung Han for lanes_count in range(1, 1 + lanes_count): 298*5f39d1b3SJooyung Han for leftovers in range(8): 299*5f39d1b3SJooyung Han row_major_with_sum.SpecializeStream('uint8_t', lanes_count, 8, leftovers) 300*5f39d1b3SJooyung Han 301*5f39d1b3SJooyung Han for lanes_count in range(1, 1 + lanes_count): 302*5f39d1b3SJooyung Han for leftovers in range(8): 303*5f39d1b3SJooyung Han column_major_with_sum.SpecializeStream('uint8_t', lanes_count, 8, 304*5f39d1b3SJooyung Han leftovers) 305