xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/runtime/common/LSTMParams.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_LSTMPARAMS_H
#define ARM_COMPUTE_LSTMPARAMS_H

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/Tensor.h"

#include <cstddef>
#include <memory>

namespace arm_compute
{
/** Class holding the optional tensors and scalar parameters of an LSTM layer: CIFG, peephole, projection, layer normalization and quantization settings */
template <typename T>
class LSTMParams
{
public:
    /** Constructor */
    LSTMParams()
        : _input_to_input_weights(nullptr),
          _recurrent_to_input_weights(nullptr),
          _cell_to_input_weights(nullptr),
          _input_gate_bias(nullptr),
          _cell_to_forget_weights(nullptr),
          _cell_to_output_weights(nullptr),
          _projection_weights(nullptr),
          _projection_bias(nullptr),
          _input_layer_norm_weights(nullptr),
          _forget_layer_norm_weights(nullptr),
          _cell_layer_norm_weights(nullptr),
          _output_layer_norm_weights(nullptr),
          _cell_clip(0.0f),
          _projection_clip(0.0f),
          _input_intermediate_scale(0.0f),
          _forget_intermediate_scale(0.0f),
          _cell_intermediate_scale(0.0f),
          _output_intermediate_scale(0.0f),
          _hidden_state_zero(0),
          _hidden_state_scale(0.0f),
          _has_peephole_opt(false),
          _has_projection(false),
          _has_cifg_opt(true),
          _use_layer_norm(false)
    {
    }
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor */
    ~LSTMParams() = default;
    /** Set CIFG tensor parameters.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        _has_cifg_opt               = false;
        return *this;
    }
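    /* Illustrative usage sketch (not part of the API): enabling a dedicated
     * input gate, i.e. turning the CIFG optimization off, for a float LSTM.
     * The tensor objects are hypothetical and assumed to be allocated and
     * filled by the caller:
     *
     *   Tensor input_to_input_w, recurrent_to_input_w, cell_to_input_w, input_gate_b;
     *   LSTMParams<ITensor> params;
     *   params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, &cell_to_input_w, &input_gate_b);
     *   // has_cifg_opt() now returns false.
     */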
    /** Set projection tensor parameters.
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p projection_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
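    /* Illustrative usage sketch (not part of the API): enabling the output
     * projection; the tensor objects are hypothetical:
     *
     *   Tensor projection_w, projection_b;
     *   params.set_projection_params(&projection_w, &projection_b);
     *   // has_projection() now returns true.
     */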
    /** Set peephole tensor parameters.
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
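    /* Illustrative usage sketch (not part of the API): enabling peephole
     * connections for the forget and output gates. The cell-to-input weights
     * used when CIFG is disabled are passed through set_cifg_params() instead.
     * Hypothetical tensors:
     *
     *   Tensor cell_to_forget_w, cell_to_output_w;
     *   params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w);
     */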
    /** Set layer normalization tensor parameters.
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }
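    /* Illustrative usage sketch (not part of the API): enabling layer
     * normalization on all four gates. With CIFG enabled the input gate is
     * absent, so its weights would typically go unused. Hypothetical tensors:
     *
     *   Tensor input_ln_w, forget_ln_w, cell_ln_w, output_ln_w;
     *   params.set_layer_normalization_params(&input_ln_w, &forget_ln_w, &cell_ln_w, &output_ln_w);
     */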

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, when projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }
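    /* Illustrative usage sketch (not part of the API): the setters return
     * *this, so clip values can be chained; the constructor defaults of 0 are
     * conventionally interpreted as "no clipping":
     *
     *   params.set_cell_clip_params(10.f).set_projection_clip_params(8.f);
     */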

    /** Set the scales of the intermediate matmul results, i.e. the inputs to layer normalization, at each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate matmul result, i.e. the input to layer normalization, at the input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate matmul result, i.e. the input to layer normalization, at the forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate matmul result, i.e. the input to layer normalization, at the cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate matmul result, i.e. the input to layer normalization, at the output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }
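    /* Illustrative usage sketch (not part of the API): these scales are only
     * meaningful for quantized LSTMs (e.g. NEQLSTMLayer), where each gate's
     * matmul result is requantized before layer normalization; the values
     * below are arbitrary:
     *
     *   params.set_matmul_scale_params(0.007f, 0.007f, 0.007f, 0.007f);
     */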

    /** Set hidden state zero point and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }
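    /* Illustrative end-to-end sketch (not part of the API): a typical chained
     * configuration of a quantized LSTM. All tensors are hypothetical and the
     * numeric values are arbitrary:
     *
     *   LSTMParams<ITensor> params;
     *   params.set_cifg_params(&i2i_w, &r2i_w, &c2i_w, &i_bias)
     *         .set_projection_params(&proj_w, &proj_b)
     *         .set_matmul_scale_params(0.007f, 0.007f, 0.007f, 0.007f)
     *         .set_hidden_state_params(-15, 0.007f); // zero point, then scale
     */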

    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    const T *projection_weights() const
    {
        return _projection_weights;
    }

    const T *projection_bias() const
    {
        return _projection_bias;
    }

    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    float cell_clip() const
    {
        return _cell_clip;
    }

    float projection_clip() const
    {
        return _projection_clip;
    }

    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    bool has_projection() const
    {
        return _has_projection;
    }

    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    const T *_input_to_input_weights;
    const T *_recurrent_to_input_weights;
    T       *_cell_to_input_weights;
    const T *_input_gate_bias;
    T       *_cell_to_forget_weights;
    T       *_cell_to_output_weights;
    const T *_projection_weights;
    const T *_projection_bias;
    T       *_input_layer_norm_weights;
    T       *_forget_layer_norm_weights;
    T       *_cell_layer_norm_weights;
    T       *_output_layer_norm_weights;
    float    _cell_clip;
    float    _projection_clip;
    float    _input_intermediate_scale;
    float    _forget_intermediate_scale;
    float    _cell_intermediate_scale;
    float    _output_intermediate_scale;
    int32_t  _hidden_state_zero;
    float    _hidden_state_scale;
    bool     _has_peephole_opt;
    bool     _has_projection;
    bool     _has_cifg_opt;
    bool     _use_layer_norm;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_LSTMPARAMS_H */