xref: /aosp_15_r20/external/libaom/av1/encoder/ml.h (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11*77c1e3ccSAndroid Build Coastguard Worker 
12*77c1e3ccSAndroid Build Coastguard Worker #ifndef AOM_AV1_ENCODER_ML_H_
13*77c1e3ccSAndroid Build Coastguard Worker #define AOM_AV1_ENCODER_ML_H_
14*77c1e3ccSAndroid Build Coastguard Worker 
15*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus
16*77c1e3ccSAndroid Build Coastguard Worker extern "C" {
17*77c1e3ccSAndroid Build Coastguard Worker #endif
18*77c1e3ccSAndroid Build Coastguard Worker 
19*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h"
20*77c1e3ccSAndroid Build Coastguard Worker 
// Upper bounds on network topology. The fixed-size arrays in the neural net
// configuration structures are dimensioned from NN_MAX_HIDDEN_LAYERS.
#define NN_MAX_HIDDEN_LAYERS 10
#define NN_MAX_NODES_PER_LAYER 128

// Configuration of a fully connected neural network: layer sizes plus
// pointers to the per-layer weight and bias tables. The tables are only read
// through this structure (they are pointers to const float).
struct NN_CONFIG {
  int num_inputs;   // Number of input nodes, i.e. features.
  int num_outputs;  // Number of output nodes.
  // Number of hidden layers, at most NN_MAX_HIDDEN_LAYERS.
  int num_hidden_layers;
  // Number of nodes for each hidden layer.
  int num_hidden_nodes[NN_MAX_HIDDEN_LAYERS];
  // Weight parameters, indexed by layer. There is one more slot than there
  // are hidden layers (presumably for the output layer -- confirm against
  // the prediction code).
  const float *weights[NN_MAX_HIDDEN_LAYERS + 1];
  // Bias parameters, indexed by layer; same number of slots as weights.
  const float *bias[NN_MAX_HIDDEN_LAYERS + 1];
};
// Typedef from struct NN_CONFIG to NN_CONFIG is in rtcd_defs
36*77c1e3ccSAndroid Build Coastguard Worker 
37*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_NN_V2
38*77c1e3ccSAndroid Build Coastguard Worker // Fully-connectedly layer configuration
39*77c1e3ccSAndroid Build Coastguard Worker struct FC_LAYER {
40*77c1e3ccSAndroid Build Coastguard Worker   const int num_inputs;   // Number of input nodes, i.e. features.
41*77c1e3ccSAndroid Build Coastguard Worker   const int num_outputs;  // Number of output nodes.
42*77c1e3ccSAndroid Build Coastguard Worker 
43*77c1e3ccSAndroid Build Coastguard Worker   float *weights;               // Weight parameters.
44*77c1e3ccSAndroid Build Coastguard Worker   float *bias;                  // Bias parameters.
45*77c1e3ccSAndroid Build Coastguard Worker   const ACTIVATION activation;  // Activation function.
46*77c1e3ccSAndroid Build Coastguard Worker 
47*77c1e3ccSAndroid Build Coastguard Worker   float *output;  // The output array.
48*77c1e3ccSAndroid Build Coastguard Worker   float *dY;      // Gradient of outputs
49*77c1e3ccSAndroid Build Coastguard Worker   float *dW;      // Gradient of weights.
50*77c1e3ccSAndroid Build Coastguard Worker   float *db;      // Gradient of bias
51*77c1e3ccSAndroid Build Coastguard Worker };
52*77c1e3ccSAndroid Build Coastguard Worker 
53*77c1e3ccSAndroid Build Coastguard Worker // NN configure structure V2
54*77c1e3ccSAndroid Build Coastguard Worker struct NN_CONFIG_V2 {
55*77c1e3ccSAndroid Build Coastguard Worker   const int num_hidden_layers;  // Number of hidden layers, max = 10.
56*77c1e3ccSAndroid Build Coastguard Worker   FC_LAYER layer[NN_MAX_HIDDEN_LAYERS + 1];  // The layer array
57*77c1e3ccSAndroid Build Coastguard Worker   const int num_logits;                      // Number of output nodes.
58*77c1e3ccSAndroid Build Coastguard Worker   float *logits;    // Raw prediction (same as output of final layer)
59*77c1e3ccSAndroid Build Coastguard Worker   const LOSS loss;  // Loss function
60*77c1e3ccSAndroid Build Coastguard Worker };
61*77c1e3ccSAndroid Build Coastguard Worker 
62*77c1e3ccSAndroid Build Coastguard Worker // Calculate prediction based on the given input features and neural net config.
63*77c1e3ccSAndroid Build Coastguard Worker // Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
64*77c1e3ccSAndroid Build Coastguard Worker // layer.
65*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_predict_v2(const float *features, NN_CONFIG_V2 *nn_config,
66*77c1e3ccSAndroid Build Coastguard Worker                        int reduce_prec, float *output);
67*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_NN_V2
68*77c1e3ccSAndroid Build Coastguard Worker 
// Applies the softmax normalization function to the input
// to get a valid probability distribution in the output:
// output[i] = exp(input[i]) / sum_{k \in [0,n)}(exp(input[k]))
//   input:  the n values to normalize.
//   output: receives the n normalized values.
//   n:      number of elements in input and output.
void av1_nn_softmax(const float *input, float *output, int n);
73*77c1e3ccSAndroid Build Coastguard Worker 
// A faster but less accurate version of av1_nn_softmax(input, output, 16).
// The element count is fixed at 16, so no length argument is taken.
void av1_nn_fast_softmax_16_c(const float *input, float *output);
76*77c1e3ccSAndroid Build Coastguard Worker 
// Applies a precision reduction to output of av1_nn_predict to prevent
// mismatches between C and SIMD implementations.
//   output:     the values to reduce; not const-qualified, so presumably
//               rewritten in place -- confirm against the implementation.
//   num_output: number of elements in output.
// A top-level const on a by-value parameter does not affect the function's
// type, so it is omitted from this declaration.
void av1_nn_output_prec_reduce(float *output, int num_output);
80*77c1e3ccSAndroid Build Coastguard Worker 
81*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus
82*77c1e3ccSAndroid Build Coastguard Worker }  // extern "C"
83*77c1e3ccSAndroid Build Coastguard Worker #endif
84*77c1e3ccSAndroid Build Coastguard Worker 
85*77c1e3ccSAndroid Build Coastguard Worker #endif  // AOM_AV1_ENCODER_ML_H_
86