/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_LSTMPARAMS_H
#define ARM_COMPUTE_LSTMPARAMS_H

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/Tensor.h"

#include <cstddef>
#include <memory>

namespace arm_compute
{
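/** Wrapper holding the optional tensor parameters of an LSTM layer
 * (CIFG, projection, peephole, layer normalization, clipping and
 * quantization settings).
 *
 * A minimal usage sketch, assuming @p T is ITensorInfo; the tensor-info
 * objects referenced below are illustrative and assumed to be initialised
 * elsewhere. Each setter returns *this, so calls can be chained:
 *
 * @code
 * LSTMParams<ITensorInfo> lstm_params;
 * lstm_params.set_cifg_params(&input_to_input, &recurrent_to_input, &cell_to_input, &input_gate_bias)
 *            .set_projection_params(&projection_weights, &projection_bias)
 *            .set_cell_clip_params(1.0f);
 * @endcode
 */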
template <typename T>
class LSTMParams
{
public:
    /** Constructor */
    LSTMParams()
        : _input_to_input_weights(nullptr),
          _recurrent_to_input_weights(nullptr),
          _cell_to_input_weights(nullptr),
          _input_gate_bias(nullptr),
          _cell_to_forget_weights(nullptr),
          _cell_to_output_weights(nullptr),
          _projection_weights(nullptr),
          _projection_bias(nullptr),
          _input_layer_norm_weights(nullptr),
          _forget_layer_norm_weights(nullptr),
          _cell_layer_norm_weights(nullptr),
          _output_layer_norm_weights(nullptr),
          _cell_clip(0.0f),
          _projection_clip(0.0f),
          _input_intermediate_scale(0.0f),
          _forget_intermediate_scale(0.0f),
          _cell_intermediate_scale(0.0f),
          _output_intermediate_scale(0.0f),
          _hidden_state_zero(0),
          _hidden_state_scale(0.0f),
          _has_peephole_opt(false),
          _has_projection(false),
          _has_cifg_opt(true),
          _use_layer_norm(false)
    {
    }
    /** Prevent instances of this class from being copied (as this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (as this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor */
    ~LSTMParams() = default;
    /** Set CIFG tensor parameters.
     *
     * Providing the input gate tensors disables the CIFG (coupled input and forget gate) optimization.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        _has_cifg_opt               = false;
        return *this;
    }
    /** Set projection tensor parameters.
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p projection_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
    /** Set peephole tensor parameters.
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
    /** Set layer normalization tensor parameters.
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }
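    /* The peephole and layer-normalization setters follow the same fluent
     * pattern as the setters above; a brief sketch (the tensor-info objects
     * are illustrative and assumed to exist):
     *
     *   lstm_params.set_peephole_params(&cell_to_forget, &cell_to_output)
     *              .set_layer_normalization_params(&in_norm, &fg_norm, &cell_norm, &out_norm);
     */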

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set the scale of the intermediate matmul result at each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate matmul result (the input to layer normalization) at the input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate matmul result (the input to layer normalization) at the forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate matmul result (the input to layer normalization) at the cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate matmul result (the input to layer normalization) at the output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set hidden state zero point and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }
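    /* For a fully quantized LSTM, the intermediate scales and the hidden-state
     * quantization are normally configured together. A hedged sketch with
     * made-up quantization values, purely for illustration:
     *
     *   lstm_params.set_matmul_scale_params(0.007f, 0.007f, 0.007f, 0.007f)
     *              .set_hidden_state_params(-15, 0.001f);
     */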

    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    const T *projection_weights() const
    {
        return _projection_weights;
    }

    const T *projection_bias() const
    {
        return _projection_bias;
    }

    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    float cell_clip() const
    {
        return _cell_clip;
    }

    float projection_clip() const
    {
        return _projection_clip;
    }

    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    bool has_projection() const
    {
        return _has_projection;
    }

    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    const T *_input_to_input_weights;
    const T *_recurrent_to_input_weights;
    T       *_cell_to_input_weights;
    const T *_input_gate_bias;
    T       *_cell_to_forget_weights;
    T       *_cell_to_output_weights;
    const T *_projection_weights;
    const T *_projection_bias;
    T       *_input_layer_norm_weights;
    T       *_forget_layer_norm_weights;
    T       *_cell_layer_norm_weights;
    T       *_output_layer_norm_weights;
    float    _cell_clip;
    float    _projection_clip;
    float    _input_intermediate_scale;
    float    _forget_intermediate_scale;
    float    _cell_intermediate_scale;
    float    _output_intermediate_scale;
    int32_t  _hidden_state_zero;
    float    _hidden_state_scale;
    bool     _has_peephole_opt;
    bool     _has_projection;
    bool     _has_cifg_opt;
    bool     _use_layer_norm;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_LSTMPARAMS_H */