1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2018-2022 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE 25*c217d954SCole Faust #define ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE 26*c217d954SCole Faust 27*c217d954SCole Faust #include "tests/Globals.h" 28*c217d954SCole Faust #include "tests/framework/Asserts.h" 29*c217d954SCole Faust #include "tests/framework/Fixture.h" 30*c217d954SCole Faust #include "tests/validation/reference/ActivationLayer.h" 31*c217d954SCole Faust #include "tests/validation/reference/ArithmeticOperations.h" 32*c217d954SCole Faust #include "tests/validation/reference/ConcatenateLayer.h" 33*c217d954SCole Faust #include "tests/validation/reference/FullyConnectedLayer.h" 34*c217d954SCole Faust #include "tests/validation/reference/GEMM.h" 35*c217d954SCole Faust #include "tests/validation/reference/MeanStdDevNormalizationLayer.h" 36*c217d954SCole Faust #include "tests/validation/reference/PixelWiseMultiplication.h" 37*c217d954SCole Faust #include "tests/validation/reference/Transpose.h" 38*c217d954SCole Faust 39*c217d954SCole Faust namespace arm_compute 40*c217d954SCole Faust { 41*c217d954SCole Faust namespace test 42*c217d954SCole Faust { 43*c217d954SCole Faust namespace validation 44*c217d954SCole Faust { 45*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename FunctionParams, typename T> 46*c217d954SCole Faust class LSTMLayerValidationFixture : public framework::Fixture 47*c217d954SCole Faust { 48*c217d954SCole Faust public: 49*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,TensorShape input_weights_shape,TensorShape recurrent_weights_shape,TensorShape cell_bias_shape,TensorShape output_cell_shape,TensorShape output_shape,TensorShape scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)50*c217d954SCole Faust void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape, 51*c217d954SCole Faust TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, 52*c217d954SCole Faust bool use_layer_norm) 53*c217d954SCole Faust { 54*c217d954SCole Faust _target = compute_target(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold, 55*c217d954SCole Faust data_type, projection_opt, peephole_opt, use_layer_norm); 56*c217d954SCole Faust _reference = compute_reference(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold, 57*c217d954SCole Faust data_type, projection_opt, peephole_opt, use_layer_norm); 58*c217d954SCole Faust } 59*c217d954SCole Faust 60*c217d954SCole Faust protected: 61*c217d954SCole Faust template <typename U> fill(U && tensor,int i)62*c217d954SCole Faust void fill(U &&tensor, int i) 63*c217d954SCole Faust { 64*c217d954SCole Faust static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); 65*c217d954SCole Faust using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; 66*c217d954SCole Faust 67*c217d954SCole Faust DistributionType distribution{ T(-1.0f), T(1.0f) }; 68*c217d954SCole Faust library->fill(tensor, distribution, i); 69*c217d954SCole Faust } 70*c217d954SCole Faust template <typename U> fill_custom_val(U && tensor,float num,int i)71*c217d954SCole Faust void fill_custom_val(U &&tensor, float num, int i) 72*c217d954SCole Faust { 73*c217d954SCole Faust static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); 74*c217d954SCole Faust using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; 75*c217d954SCole Faust 76*c217d954SCole Faust DistributionType distribution{ T(num), T(num) }; 77*c217d954SCole Faust library->fill(tensor, distribution, i); 78*c217d954SCole Faust } compute_target(const TensorShape & input_shape,const TensorShape & input_weights_shape,const TensorShape & recurrent_weights_shape,const TensorShape & cell_bias_shape,const TensorShape & output_cell_shape,const TensorShape & output_shape,const TensorShape & scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)79*c217d954SCole Faust TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape, 80*c217d954SCole Faust const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold, 81*c217d954SCole Faust float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm) 82*c217d954SCole Faust { 83*c217d954SCole Faust const unsigned int num_cells = input_weights_shape.y(); 84*c217d954SCole Faust const unsigned int num_outputs = recurrent_weights_shape.x(); 85*c217d954SCole Faust 86*c217d954SCole Faust // Create tensors 87*c217d954SCole Faust TensorType input = create_tensor<TensorType>(input_shape, data_type); 88*c217d954SCole Faust TensorType input_to_forget_w = create_tensor<TensorType>(input_weights_shape, data_type); 89*c217d954SCole Faust TensorType input_to_cell_w = create_tensor<TensorType>(input_weights_shape, data_type); 90*c217d954SCole Faust TensorType input_to_output_w = create_tensor<TensorType>(input_weights_shape, data_type); 91*c217d954SCole Faust TensorType recurrent_to_forget_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 92*c217d954SCole Faust TensorType recurrent_to_cell_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 93*c217d954SCole Faust TensorType recurrent_to_output_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 94*c217d954SCole Faust TensorType forget_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 95*c217d954SCole Faust TensorType cell_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 96*c217d954SCole Faust TensorType output_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 97*c217d954SCole Faust TensorType output_state_in = create_tensor<TensorType>(output_shape, data_type); 98*c217d954SCole Faust TensorType cell_state_in = create_tensor<TensorType>(output_cell_shape, data_type); 99*c217d954SCole Faust TensorType scratch = create_tensor<TensorType>(scratch_shape, data_type); 100*c217d954SCole Faust TensorType output_state_out = create_tensor<TensorType>(output_shape, data_type); 101*c217d954SCole Faust TensorType cell_state_out = create_tensor<TensorType>(output_cell_shape, data_type); 102*c217d954SCole Faust TensorType output = create_tensor<TensorType>(output_shape, data_type); 103*c217d954SCole Faust TensorType input_to_input_w; 104*c217d954SCole Faust TensorType recurrent_to_input_w; 105*c217d954SCole Faust TensorType cell_to_input_w; 106*c217d954SCole Faust TensorType cell_to_forget_w; 107*c217d954SCole Faust TensorType input_gate_bias; 108*c217d954SCole Faust TensorType cell_to_output_w; 109*c217d954SCole Faust TensorType projection_w; 110*c217d954SCole Faust TensorType projection_bias; 111*c217d954SCole Faust TensorType input_layer_norm_w; 112*c217d954SCole Faust TensorType forget_layer_norm_w; 113*c217d954SCole Faust TensorType cell_layer_norm_w; 114*c217d954SCole Faust TensorType output_layer_norm_w; 115*c217d954SCole Faust 116*c217d954SCole Faust bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true; 117*c217d954SCole Faust 118*c217d954SCole Faust FunctionParams lstm_params; 119*c217d954SCole Faust 120*c217d954SCole Faust if(!cifg_opt) 121*c217d954SCole Faust { 122*c217d954SCole Faust input_to_input_w = create_tensor<TensorType>(input_weights_shape, data_type); 123*c217d954SCole Faust recurrent_to_input_w = create_tensor<TensorType>(recurrent_weights_shape, data_type); 124*c217d954SCole Faust if(peephole_opt) 125*c217d954SCole Faust { 126*c217d954SCole Faust cell_to_input_w = create_tensor<TensorType>(cell_bias_shape, data_type); 127*c217d954SCole Faust } 128*c217d954SCole Faust input_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type); 129*c217d954SCole Faust lstm_params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, &cell_to_input_w, &input_gate_bias); 130*c217d954SCole Faust } 131*c217d954SCole Faust 132*c217d954SCole Faust if(peephole_opt) 133*c217d954SCole Faust { 134*c217d954SCole Faust cell_to_forget_w = create_tensor<TensorType>(cell_bias_shape, data_type); 135*c217d954SCole Faust cell_to_output_w = create_tensor<TensorType>(cell_bias_shape, data_type); 136*c217d954SCole Faust lstm_params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w); 137*c217d954SCole Faust } 138*c217d954SCole Faust 139*c217d954SCole Faust if(projection_opt) 140*c217d954SCole Faust { 141*c217d954SCole Faust projection_w = create_tensor<TensorType>(TensorShape(num_cells, num_outputs), data_type); 142*c217d954SCole Faust projection_bias = create_tensor<TensorType>(TensorShape(num_outputs), data_type); 143*c217d954SCole Faust lstm_params.set_projection_params(&projection_w, &projection_bias); 144*c217d954SCole Faust } 145*c217d954SCole Faust 146*c217d954SCole Faust if(use_layer_norm) 147*c217d954SCole Faust { 148*c217d954SCole Faust forget_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 149*c217d954SCole Faust cell_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 150*c217d954SCole Faust output_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 151*c217d954SCole Faust if(!cifg_opt) 152*c217d954SCole Faust { 153*c217d954SCole Faust input_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type); 154*c217d954SCole Faust lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w); 155*c217d954SCole Faust } 156*c217d954SCole Faust else 157*c217d954SCole Faust { 158*c217d954SCole Faust lstm_params.set_layer_normalization_params(nullptr, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w); 159*c217d954SCole Faust } 160*c217d954SCole Faust } 161*c217d954SCole Faust 162*c217d954SCole Faust // Create and configure function 163*c217d954SCole Faust FunctionType lstm; 164*c217d954SCole Faust lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w, 165*c217d954SCole Faust &recurrent_to_cell_w, &recurrent_to_output_w, &forget_gate_bias, &cell_bias, &output_gate_bias, 166*c217d954SCole Faust &output_state_in, &cell_state_in, 167*c217d954SCole Faust &scratch, &output_state_out, &cell_state_out, &output, 168*c217d954SCole Faust lstm_params, info, cell_threshold, projection_threshold); 169*c217d954SCole Faust 170*c217d954SCole Faust ARM_COMPUTE_ASSERT(input.info()->is_resizable()); 171*c217d954SCole Faust ARM_COMPUTE_ASSERT(input_to_forget_w.info()->is_resizable()); 172*c217d954SCole Faust ARM_COMPUTE_ASSERT(input_to_cell_w.info()->is_resizable()); 173*c217d954SCole Faust ARM_COMPUTE_ASSERT(input_to_output_w.info()->is_resizable()); 174*c217d954SCole Faust ARM_COMPUTE_ASSERT(recurrent_to_forget_w.info()->is_resizable()); 175*c217d954SCole Faust ARM_COMPUTE_ASSERT(recurrent_to_cell_w.info()->is_resizable()); 176*c217d954SCole Faust ARM_COMPUTE_ASSERT(recurrent_to_output_w.info()->is_resizable()); 177*c217d954SCole Faust ARM_COMPUTE_ASSERT(forget_gate_bias.info()->is_resizable()); 178*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_bias.info()->is_resizable()); 179*c217d954SCole Faust ARM_COMPUTE_ASSERT(output_gate_bias.info()->is_resizable()); 180*c217d954SCole Faust ARM_COMPUTE_ASSERT(output_state_in.info()->is_resizable()); 181*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_state_in.info()->is_resizable()); 182*c217d954SCole Faust ARM_COMPUTE_ASSERT(scratch.info()->is_resizable()); 183*c217d954SCole Faust ARM_COMPUTE_ASSERT(output_state_out.info()->is_resizable()); 184*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_state_out.info()->is_resizable()); 185*c217d954SCole Faust ARM_COMPUTE_ASSERT(output.info()->is_resizable()); 186*c217d954SCole Faust 187*c217d954SCole Faust // Allocate tensors 188*c217d954SCole Faust input.allocator()->allocate(); 189*c217d954SCole Faust input_to_forget_w.allocator()->allocate(); 190*c217d954SCole Faust input_to_cell_w.allocator()->allocate(); 191*c217d954SCole Faust input_to_output_w.allocator()->allocate(); 192*c217d954SCole Faust recurrent_to_forget_w.allocator()->allocate(); 193*c217d954SCole Faust recurrent_to_cell_w.allocator()->allocate(); 194*c217d954SCole Faust recurrent_to_output_w.allocator()->allocate(); 195*c217d954SCole Faust forget_gate_bias.allocator()->allocate(); 196*c217d954SCole Faust cell_bias.allocator()->allocate(); 197*c217d954SCole Faust output_gate_bias.allocator()->allocate(); 198*c217d954SCole Faust output_state_in.allocator()->allocate(); 199*c217d954SCole Faust cell_state_in.allocator()->allocate(); 200*c217d954SCole Faust scratch.allocator()->allocate(); 201*c217d954SCole Faust output_state_out.allocator()->allocate(); 202*c217d954SCole Faust cell_state_out.allocator()->allocate(); 203*c217d954SCole Faust output.allocator()->allocate(); 204*c217d954SCole Faust 205*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input.info()->is_resizable()); 206*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input_to_forget_w.info()->is_resizable()); 207*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input_to_cell_w.info()->is_resizable()); 208*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input_to_output_w.info()->is_resizable()); 209*c217d954SCole Faust ARM_COMPUTE_ASSERT(!recurrent_to_forget_w.info()->is_resizable()); 210*c217d954SCole Faust ARM_COMPUTE_ASSERT(!recurrent_to_cell_w.info()->is_resizable()); 211*c217d954SCole Faust ARM_COMPUTE_ASSERT(!recurrent_to_output_w.info()->is_resizable()); 212*c217d954SCole Faust ARM_COMPUTE_ASSERT(!forget_gate_bias.info()->is_resizable()); 213*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_bias.info()->is_resizable()); 214*c217d954SCole Faust ARM_COMPUTE_ASSERT(!output_gate_bias.info()->is_resizable()); 215*c217d954SCole Faust ARM_COMPUTE_ASSERT(!output_state_in.info()->is_resizable()); 216*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_state_in.info()->is_resizable()); 217*c217d954SCole Faust ARM_COMPUTE_ASSERT(!scratch.info()->is_resizable()); 218*c217d954SCole Faust ARM_COMPUTE_ASSERT(!output_state_out.info()->is_resizable()); 219*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_state_out.info()->is_resizable()); 220*c217d954SCole Faust ARM_COMPUTE_ASSERT(!output.info()->is_resizable()); 221*c217d954SCole Faust 222*c217d954SCole Faust // Fill tensors 223*c217d954SCole Faust fill(AccessorType(input), 0); 224*c217d954SCole Faust fill(AccessorType(input_to_forget_w), 1); 225*c217d954SCole Faust fill(AccessorType(input_to_cell_w), 2); 226*c217d954SCole Faust fill(AccessorType(input_to_output_w), 3); 227*c217d954SCole Faust fill(AccessorType(recurrent_to_forget_w), 4); 228*c217d954SCole Faust fill(AccessorType(recurrent_to_cell_w), 5); 229*c217d954SCole Faust fill(AccessorType(recurrent_to_output_w), 6); 230*c217d954SCole Faust fill(AccessorType(forget_gate_bias), 7); 231*c217d954SCole Faust fill(AccessorType(cell_bias), 8); 232*c217d954SCole Faust fill(AccessorType(output_gate_bias), 9); 233*c217d954SCole Faust fill(AccessorType(output_state_in), 10); 234*c217d954SCole Faust fill(AccessorType(cell_state_in), 11); 235*c217d954SCole Faust fill(AccessorType(scratch), 12); 236*c217d954SCole Faust 237*c217d954SCole Faust if(!cifg_opt) 238*c217d954SCole Faust { 239*c217d954SCole Faust ARM_COMPUTE_ASSERT(input_to_input_w.info()->is_resizable()); 240*c217d954SCole Faust ARM_COMPUTE_ASSERT(recurrent_to_input_w.info()->is_resizable()); 241*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_to_input_w.info()->is_resizable()); 242*c217d954SCole Faust ARM_COMPUTE_ASSERT(input_gate_bias.info()->is_resizable()); 243*c217d954SCole Faust input_to_input_w.allocator()->allocate(); 244*c217d954SCole Faust recurrent_to_input_w.allocator()->allocate(); 245*c217d954SCole Faust cell_to_input_w.allocator()->allocate(); 246*c217d954SCole Faust input_gate_bias.allocator()->allocate(); 247*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input_to_input_w.info()->is_resizable()); 248*c217d954SCole Faust ARM_COMPUTE_ASSERT(!recurrent_to_input_w.info()->is_resizable()); 249*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_to_input_w.info()->is_resizable()); 250*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input_gate_bias.info()->is_resizable()); 251*c217d954SCole Faust fill(AccessorType(input_to_input_w), 13); 252*c217d954SCole Faust fill(AccessorType(recurrent_to_input_w), 14); 253*c217d954SCole Faust if(peephole_opt) 254*c217d954SCole Faust { 255*c217d954SCole Faust fill(AccessorType(cell_to_input_w), 15); 256*c217d954SCole Faust } 257*c217d954SCole Faust fill(AccessorType(recurrent_to_input_w), 16); 258*c217d954SCole Faust fill(AccessorType(input_gate_bias), 17); 259*c217d954SCole Faust } 260*c217d954SCole Faust 261*c217d954SCole Faust if(peephole_opt) 262*c217d954SCole Faust { 263*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_to_forget_w.info()->is_resizable()); 264*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_to_output_w.info()->is_resizable()); 265*c217d954SCole Faust cell_to_forget_w.allocator()->allocate(); 266*c217d954SCole Faust cell_to_output_w.allocator()->allocate(); 267*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_to_forget_w.info()->is_resizable()); 268*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_to_output_w.info()->is_resizable()); 269*c217d954SCole Faust fill(AccessorType(cell_to_forget_w), 18); 270*c217d954SCole Faust fill(AccessorType(cell_to_output_w), 19); 271*c217d954SCole Faust } 272*c217d954SCole Faust 273*c217d954SCole Faust if(projection_opt) 274*c217d954SCole Faust { 275*c217d954SCole Faust ARM_COMPUTE_ASSERT(projection_w.info()->is_resizable()); 276*c217d954SCole Faust ARM_COMPUTE_ASSERT(projection_bias.info()->is_resizable()); 277*c217d954SCole Faust 278*c217d954SCole Faust projection_w.allocator()->allocate(); 279*c217d954SCole Faust projection_bias.allocator()->allocate(); 280*c217d954SCole Faust 281*c217d954SCole Faust ARM_COMPUTE_ASSERT(!projection_w.info()->is_resizable()); 282*c217d954SCole Faust ARM_COMPUTE_ASSERT(!projection_bias.info()->is_resizable()); 283*c217d954SCole Faust 284*c217d954SCole Faust fill(AccessorType(projection_w), 20); 285*c217d954SCole Faust fill(AccessorType(projection_bias), 21); 286*c217d954SCole Faust } 287*c217d954SCole Faust 288*c217d954SCole Faust if(use_layer_norm) 289*c217d954SCole Faust { 290*c217d954SCole Faust if(!cifg_opt) 291*c217d954SCole Faust { 292*c217d954SCole Faust ARM_COMPUTE_ASSERT(input_layer_norm_w.info()->is_resizable()); 293*c217d954SCole Faust 294*c217d954SCole Faust input_layer_norm_w.allocator()->allocate(); 295*c217d954SCole Faust 296*c217d954SCole Faust ARM_COMPUTE_ASSERT(!input_layer_norm_w.info()->is_resizable()); 297*c217d954SCole Faust 298*c217d954SCole Faust fill(AccessorType(input_layer_norm_w), 22); 299*c217d954SCole Faust } 300*c217d954SCole Faust ARM_COMPUTE_ASSERT(forget_layer_norm_w.info()->is_resizable()); 301*c217d954SCole Faust ARM_COMPUTE_ASSERT(cell_layer_norm_w.info()->is_resizable()); 302*c217d954SCole Faust ARM_COMPUTE_ASSERT(output_layer_norm_w.info()->is_resizable()); 303*c217d954SCole Faust 304*c217d954SCole Faust forget_layer_norm_w.allocator()->allocate(); 305*c217d954SCole Faust cell_layer_norm_w.allocator()->allocate(); 306*c217d954SCole Faust output_layer_norm_w.allocator()->allocate(); 307*c217d954SCole Faust 308*c217d954SCole Faust ARM_COMPUTE_ASSERT(!forget_layer_norm_w.info()->is_resizable()); 309*c217d954SCole Faust ARM_COMPUTE_ASSERT(!cell_layer_norm_w.info()->is_resizable()); 310*c217d954SCole Faust ARM_COMPUTE_ASSERT(!output_layer_norm_w.info()->is_resizable()); 311*c217d954SCole Faust 312*c217d954SCole Faust fill(AccessorType(forget_layer_norm_w), 23); 313*c217d954SCole Faust fill(AccessorType(cell_layer_norm_w), 24); 314*c217d954SCole Faust fill(AccessorType(output_layer_norm_w), 25); 315*c217d954SCole Faust } 316*c217d954SCole Faust 317*c217d954SCole Faust // Compute function 318*c217d954SCole Faust lstm.run(); 319*c217d954SCole Faust 320*c217d954SCole Faust _target_scratch = std::move(scratch); 321*c217d954SCole Faust return output; 322*c217d954SCole Faust } 323*c217d954SCole Faust compute_reference(const TensorShape & input_shape,const TensorShape & input_weights_shape,const TensorShape & recurrent_weights_shape,const TensorShape & cell_bias_shape,const TensorShape & output_cell_shape,const TensorShape & output_shape,const TensorShape & scratch_shape,ActivationLayerInfo info,float cell_threshold,float projection_threshold,DataType data_type,bool projection_opt,bool peephole_opt,bool use_layer_norm)324*c217d954SCole Faust SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape, 325*c217d954SCole Faust const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold, 326*c217d954SCole Faust float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm) 327*c217d954SCole Faust { 328*c217d954SCole Faust const unsigned int num_cells = input_weights_shape.y(); 329*c217d954SCole Faust const unsigned int num_outputs = recurrent_weights_shape.x(); 330*c217d954SCole Faust 331*c217d954SCole Faust // Create projection weights shape 332*c217d954SCole Faust TensorShape projection_weights_shape(num_cells, num_outputs); 333*c217d954SCole Faust 334*c217d954SCole Faust // Create projection bias shape 335*c217d954SCole Faust TensorShape projection_bias_shape(num_outputs); 336*c217d954SCole Faust 337*c217d954SCole Faust TensorShape gemm_shape{ 1, output_shape.y() }; 338*c217d954SCole Faust SimpleTensor<T> gemm_out{ gemm_shape, data_type }; 339*c217d954SCole Faust 340*c217d954SCole Faust // Create reference 341*c217d954SCole Faust SimpleTensor<T> input{ input_shape, data_type }; 342*c217d954SCole Faust SimpleTensor<T> input_to_input_w{ input_weights_shape, data_type }; 343*c217d954SCole Faust SimpleTensor<T> input_to_forget_w{ input_weights_shape, data_type }; 344*c217d954SCole Faust SimpleTensor<T> input_to_cell_w{ input_weights_shape, data_type }; 345*c217d954SCole Faust SimpleTensor<T> input_to_output_w{ input_weights_shape, data_type }; 346*c217d954SCole Faust SimpleTensor<T> recurrent_to_input_w{ recurrent_weights_shape, data_type }; 347*c217d954SCole Faust SimpleTensor<T> recurrent_to_forget_w{ recurrent_weights_shape, data_type }; 348*c217d954SCole Faust SimpleTensor<T> recurrent_to_cell_w{ recurrent_weights_shape, data_type }; 349*c217d954SCole Faust SimpleTensor<T> recurrent_to_output_w{ recurrent_weights_shape, data_type }; 350*c217d954SCole Faust SimpleTensor<T> cell_to_input_w{ cell_bias_shape, data_type }; 351*c217d954SCole Faust SimpleTensor<T> cell_to_forget_w{ cell_bias_shape, data_type }; 352*c217d954SCole Faust SimpleTensor<T> cell_to_output_w{ cell_bias_shape, data_type }; 353*c217d954SCole Faust SimpleTensor<T> input_gate_bias{ cell_bias_shape, data_type }; 354*c217d954SCole Faust SimpleTensor<T> forget_gate_bias{ cell_bias_shape, data_type }; 355*c217d954SCole Faust SimpleTensor<T> cell_bias{ cell_bias_shape, data_type }; 356*c217d954SCole Faust SimpleTensor<T> output_gate_bias{ cell_bias_shape, data_type }; 357*c217d954SCole Faust SimpleTensor<T> projection_w{ projection_weights_shape, data_type }; 358*c217d954SCole Faust SimpleTensor<T> projection_bias{ projection_bias_shape, data_type }; 359*c217d954SCole Faust SimpleTensor<T> output_state_in{ output_shape, data_type }; 360*c217d954SCole Faust SimpleTensor<T> cell_state_in{ output_cell_shape, data_type }; 361*c217d954SCole Faust SimpleTensor<T> scratch{ scratch_shape, data_type }; 362*c217d954SCole Faust SimpleTensor<T> output_state_out{ output_shape, data_type }; 363*c217d954SCole Faust SimpleTensor<T> cell_state_out{ output_cell_shape, data_type }; 364*c217d954SCole Faust SimpleTensor<T> output{ output_shape, data_type }; 365*c217d954SCole Faust 366*c217d954SCole Faust bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true; 367*c217d954SCole Faust 368*c217d954SCole Faust // Fill reference 369*c217d954SCole Faust fill(input, 0); 370*c217d954SCole Faust fill(input_to_forget_w, 1); 371*c217d954SCole Faust fill(input_to_cell_w, 2); 372*c217d954SCole Faust fill(input_to_output_w, 3); 373*c217d954SCole Faust fill(recurrent_to_forget_w, 4); 374*c217d954SCole Faust fill(recurrent_to_cell_w, 5); 375*c217d954SCole Faust fill(recurrent_to_output_w, 6); 376*c217d954SCole Faust if(use_layer_norm) 377*c217d954SCole Faust { 378*c217d954SCole Faust fill_custom_val(forget_gate_bias, 0.f, 7); 379*c217d954SCole Faust fill_custom_val(cell_bias, 0.f, 8); 380*c217d954SCole Faust fill_custom_val(output_gate_bias, 0.f, 9); 381*c217d954SCole Faust } 382*c217d954SCole Faust else 383*c217d954SCole Faust { 384*c217d954SCole Faust fill(forget_gate_bias, 7); 385*c217d954SCole Faust fill(cell_bias, 8); 386*c217d954SCole Faust fill(output_gate_bias, 9); 387*c217d954SCole Faust } 388*c217d954SCole Faust fill(output_state_in, 10); 389*c217d954SCole Faust fill(cell_state_in, 11); 390*c217d954SCole Faust fill(scratch, 12); 391*c217d954SCole Faust fill(input_to_input_w, 13); 392*c217d954SCole Faust fill(recurrent_to_input_w, 14); 393*c217d954SCole Faust fill(cell_to_input_w, 15); 394*c217d954SCole Faust fill(recurrent_to_input_w, 16); 395*c217d954SCole Faust if(!cifg_opt && use_layer_norm) 396*c217d954SCole Faust { 397*c217d954SCole Faust fill_custom_val(input_gate_bias, 0.f, 17); 398*c217d954SCole Faust } 399*c217d954SCole Faust else 400*c217d954SCole Faust { 401*c217d954SCole Faust fill(input_gate_bias, 17); 402*c217d954SCole Faust } 403*c217d954SCole Faust fill(cell_to_forget_w, 18); 404*c217d954SCole Faust fill(cell_to_output_w, 19); 405*c217d954SCole Faust fill(projection_w, 20); 406*c217d954SCole Faust fill(projection_bias, 21); 407*c217d954SCole Faust 408*c217d954SCole Faust // Compute forget_gate 409*c217d954SCole Faust SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape); 410*c217d954SCole Faust SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w); 411*c217d954SCole Faust SimpleTensor<T> gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f); 412*c217d954SCole Faust SimpleTensor<T> forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE); 413*c217d954SCole Faust 414*c217d954SCole Faust if(peephole_opt) 415*c217d954SCole Faust { 416*c217d954SCole Faust SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type); 417*c217d954SCole Faust forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); 418*c217d954SCole Faust } 419*c217d954SCole Faust 420*c217d954SCole Faust if(use_layer_norm) 421*c217d954SCole Faust { 422*c217d954SCole Faust SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type }; 423*c217d954SCole Faust fill(forget_layer_norm_w, 23); 424*c217d954SCole Faust forget_gate = reference::mean_std_normalization_layer(forget_gate); 425*c217d954SCole Faust forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 426*c217d954SCole Faust fill(forget_gate_bias, 7); 427*c217d954SCole Faust forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE); 428*c217d954SCole Faust } 429*c217d954SCole Faust forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); 430*c217d954SCole Faust 431*c217d954SCole Faust // Compute input_gate 432*c217d954SCole Faust SimpleTensor<T> input_gate; 433*c217d954SCole Faust if(cifg_opt) 434*c217d954SCole Faust { 435*c217d954SCole Faust SimpleTensor<T> ones{ cell_bias_shape, data_type }; 436*c217d954SCole Faust fill_custom_val(ones, 1.f, 0); 437*c217d954SCole Faust input_gate = reference::arithmetic_operation<T>(reference::ArithmeticOperation::SUB, ones, forget_gate, data_type, ConvertPolicy::SATURATE); 438*c217d954SCole Faust } 439*c217d954SCole Faust else 440*c217d954SCole Faust { 441*c217d954SCole Faust SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); 442*c217d954SCole Faust transposed_weights = reference::transpose(recurrent_to_input_w); 443*c217d954SCole Faust gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f); 444*c217d954SCole Faust input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); 445*c217d954SCole Faust if(peephole_opt) 446*c217d954SCole Faust { 447*c217d954SCole Faust SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 448*c217d954SCole Faust input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE); 449*c217d954SCole Faust } 450*c217d954SCole Faust if(use_layer_norm) 451*c217d954SCole Faust { 452*c217d954SCole Faust SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type }; 453*c217d954SCole Faust fill(input_layer_norm_w, 22); 454*c217d954SCole Faust input_gate = reference::mean_std_normalization_layer(input_gate); 455*c217d954SCole Faust input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 456*c217d954SCole Faust fill(input_gate_bias, 17); 457*c217d954SCole Faust input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE); 458*c217d954SCole Faust } 459*c217d954SCole Faust input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); 460*c217d954SCole Faust } 461*c217d954SCole Faust // Compute cell_state 462*c217d954SCole Faust SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape); 463*c217d954SCole Faust transposed_weights = reference::transpose(recurrent_to_cell_w); 464*c217d954SCole Faust gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); 465*c217d954SCole Faust SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 466*c217d954SCole Faust cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE); 467*c217d954SCole Faust if(use_layer_norm) 468*c217d954SCole Faust { 469*c217d954SCole Faust SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type }; 470*c217d954SCole Faust fill(cell_layer_norm_w, 24); 471*c217d954SCole Faust cell_state_out = reference::mean_std_normalization_layer(cell_state_out); 472*c217d954SCole Faust cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 473*c217d954SCole Faust fill(cell_bias, 8); 474*c217d954SCole Faust cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE); 475*c217d954SCole Faust } 476*c217d954SCole Faust cell_state_out = reference::activation_layer(cell_state_out, info); 477*c217d954SCole Faust cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 478*c217d954SCole Faust cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); 479*c217d954SCole Faust 480*c217d954SCole Faust if(cell_threshold != 0.f) 481*c217d954SCole Faust { 482*c217d954SCole Faust cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold)); 483*c217d954SCole Faust } 484*c217d954SCole Faust 485*c217d954SCole Faust // Compute output 486*c217d954SCole Faust SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape); 487*c217d954SCole Faust transposed_weights = reference::transpose(recurrent_to_output_w); 488*c217d954SCole Faust gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); 489*c217d954SCole Faust output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); 490*c217d954SCole Faust if(peephole_opt) 491*c217d954SCole Faust { 492*c217d954SCole Faust pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 493*c217d954SCole Faust output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); 494*c217d954SCole Faust } 495*c217d954SCole Faust if(use_layer_norm) 496*c217d954SCole Faust { 497*c217d954SCole Faust SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type }; 498*c217d954SCole Faust fill(output_layer_norm_w, 25); 499*c217d954SCole Faust output = reference::mean_std_normalization_layer(output); 500*c217d954SCole Faust output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 501*c217d954SCole Faust fill(output_gate_bias, 9); 502*c217d954SCole Faust output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE); 503*c217d954SCole Faust } 504*c217d954SCole Faust output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); 505*c217d954SCole Faust 506*c217d954SCole Faust // Compute output state 507*c217d954SCole Faust SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info); 508*c217d954SCole Faust output_state_out = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); 509*c217d954SCole Faust 510*c217d954SCole Faust if(projection_opt) 511*c217d954SCole Faust { 512*c217d954SCole Faust SimpleTensor<T> fully_connected_projection = reference::fully_connected_layer(output_state_out, projection_w, projection_bias, output_cell_shape); 513*c217d954SCole Faust if(projection_threshold != 0.f) 514*c217d954SCole Faust { 515*c217d954SCole Faust output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)); 516*c217d954SCole Faust } 517*c217d954SCole Faust } 518*c217d954SCole Faust std::vector<SimpleTensor<T>> scratch_inputs; 519*c217d954SCole Faust if(!cifg_opt) 520*c217d954SCole Faust { 521*c217d954SCole Faust scratch_inputs.emplace_back(std::move(input_gate)); 522*c217d954SCole Faust } 523*c217d954SCole Faust scratch_inputs.emplace_back(std::move(cell_state_out)); 524*c217d954SCole Faust scratch_inputs.emplace_back(std::move(forget_gate)); 525*c217d954SCole Faust scratch_inputs.emplace_back(std::move(output)); 526*c217d954SCole Faust scratch = reference::concatenate_layer(scratch_inputs, scratch, Window::DimX); 527*c217d954SCole Faust _reference_scratch = std::move(scratch); 528*c217d954SCole Faust return output_state_out; 529*c217d954SCole Faust } 530*c217d954SCole Faust 531*c217d954SCole Faust TensorType _target{}; 532*c217d954SCole Faust TensorType _target_scratch{}; 533*c217d954SCole Faust SimpleTensor<T> _reference{}; 534*c217d954SCole Faust SimpleTensor<T> _reference_scratch{}; 535*c217d954SCole Faust }; 536*c217d954SCole Faust } // namespace validation 537*c217d954SCole Faust } // namespace test 538*c217d954SCole Faust } // namespace arm_compute 539*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_LSTM_LAYER_FIXTURE */ 540