1*77c1e3ccSAndroid Build Coastguard Worker /* 2*77c1e3ccSAndroid Build Coastguard Worker * Copyright (c) 2016, Alliance for Open Media. All rights reserved. 3*77c1e3ccSAndroid Build Coastguard Worker * 4*77c1e3ccSAndroid Build Coastguard Worker * This source code is subject to the terms of the BSD 2 Clause License and 5*77c1e3ccSAndroid Build Coastguard Worker * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6*77c1e3ccSAndroid Build Coastguard Worker * was not distributed with this source code in the LICENSE file, you can 7*77c1e3ccSAndroid Build Coastguard Worker * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8*77c1e3ccSAndroid Build Coastguard Worker * Media Patent License 1.0 was not distributed with this source code in the 9*77c1e3ccSAndroid Build Coastguard Worker * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10*77c1e3ccSAndroid Build Coastguard Worker */ 11*77c1e3ccSAndroid Build Coastguard Worker 12*77c1e3ccSAndroid Build Coastguard Worker #ifndef AOM_AV1_ENCODER_ML_H_ 13*77c1e3ccSAndroid Build Coastguard Worker #define AOM_AV1_ENCODER_ML_H_ 14*77c1e3ccSAndroid Build Coastguard Worker 15*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus 16*77c1e3ccSAndroid Build Coastguard Worker extern "C" { 17*77c1e3ccSAndroid Build Coastguard Worker #endif 18*77c1e3ccSAndroid Build Coastguard Worker 19*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h" 20*77c1e3ccSAndroid Build Coastguard Worker 21*77c1e3ccSAndroid Build Coastguard Worker #define NN_MAX_HIDDEN_LAYERS 10 22*77c1e3ccSAndroid Build Coastguard Worker #define NN_MAX_NODES_PER_LAYER 128 23*77c1e3ccSAndroid Build Coastguard Worker 24*77c1e3ccSAndroid Build Coastguard Worker struct NN_CONFIG { 25*77c1e3ccSAndroid Build Coastguard Worker int num_inputs; // Number of input nodes, i.e. features. 26*77c1e3ccSAndroid Build Coastguard Worker int num_outputs; // Number of output nodes. 27*77c1e3ccSAndroid Build Coastguard Worker int num_hidden_layers; // Number of hidden layers, maximum 10. 28*77c1e3ccSAndroid Build Coastguard Worker // Number of nodes for each hidden layer. 29*77c1e3ccSAndroid Build Coastguard Worker int num_hidden_nodes[NN_MAX_HIDDEN_LAYERS]; 30*77c1e3ccSAndroid Build Coastguard Worker // Weight parameters, indexed by layer. 31*77c1e3ccSAndroid Build Coastguard Worker const float *weights[NN_MAX_HIDDEN_LAYERS + 1]; 32*77c1e3ccSAndroid Build Coastguard Worker // Bias parameters, indexed by layer. 33*77c1e3ccSAndroid Build Coastguard Worker const float *bias[NN_MAX_HIDDEN_LAYERS + 1]; 34*77c1e3ccSAndroid Build Coastguard Worker }; 35*77c1e3ccSAndroid Build Coastguard Worker // Typedef from struct NN_CONFIG to NN_CONFIG is in rtcd_defs 36*77c1e3ccSAndroid Build Coastguard Worker 37*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_NN_V2 38*77c1e3ccSAndroid Build Coastguard Worker // Fully-connectedly layer configuration 39*77c1e3ccSAndroid Build Coastguard Worker struct FC_LAYER { 40*77c1e3ccSAndroid Build Coastguard Worker const int num_inputs; // Number of input nodes, i.e. features. 41*77c1e3ccSAndroid Build Coastguard Worker const int num_outputs; // Number of output nodes. 42*77c1e3ccSAndroid Build Coastguard Worker 43*77c1e3ccSAndroid Build Coastguard Worker float *weights; // Weight parameters. 44*77c1e3ccSAndroid Build Coastguard Worker float *bias; // Bias parameters. 45*77c1e3ccSAndroid Build Coastguard Worker const ACTIVATION activation; // Activation function. 46*77c1e3ccSAndroid Build Coastguard Worker 47*77c1e3ccSAndroid Build Coastguard Worker float *output; // The output array. 48*77c1e3ccSAndroid Build Coastguard Worker float *dY; // Gradient of outputs 49*77c1e3ccSAndroid Build Coastguard Worker float *dW; // Gradient of weights. 50*77c1e3ccSAndroid Build Coastguard Worker float *db; // Gradient of bias 51*77c1e3ccSAndroid Build Coastguard Worker }; 52*77c1e3ccSAndroid Build Coastguard Worker 53*77c1e3ccSAndroid Build Coastguard Worker // NN configure structure V2 54*77c1e3ccSAndroid Build Coastguard Worker struct NN_CONFIG_V2 { 55*77c1e3ccSAndroid Build Coastguard Worker const int num_hidden_layers; // Number of hidden layers, max = 10. 56*77c1e3ccSAndroid Build Coastguard Worker FC_LAYER layer[NN_MAX_HIDDEN_LAYERS + 1]; // The layer array 57*77c1e3ccSAndroid Build Coastguard Worker const int num_logits; // Number of output nodes. 58*77c1e3ccSAndroid Build Coastguard Worker float *logits; // Raw prediction (same as output of final layer) 59*77c1e3ccSAndroid Build Coastguard Worker const LOSS loss; // Loss function 60*77c1e3ccSAndroid Build Coastguard Worker }; 61*77c1e3ccSAndroid Build Coastguard Worker 62*77c1e3ccSAndroid Build Coastguard Worker // Calculate prediction based on the given input features and neural net config. 63*77c1e3ccSAndroid Build Coastguard Worker // Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden 64*77c1e3ccSAndroid Build Coastguard Worker // layer. 65*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_predict_v2(const float *features, NN_CONFIG_V2 *nn_config, 66*77c1e3ccSAndroid Build Coastguard Worker int reduce_prec, float *output); 67*77c1e3ccSAndroid Build Coastguard Worker #endif // CONFIG_NN_V2 68*77c1e3ccSAndroid Build Coastguard Worker 69*77c1e3ccSAndroid Build Coastguard Worker // Applies the softmax normalization function to the input 70*77c1e3ccSAndroid Build Coastguard Worker // to get a valid probability distribution in the output: 71*77c1e3ccSAndroid Build Coastguard Worker // output[i] = exp(input[i]) / sum_{k \in [0,n)}(exp(input[k])) 72*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_softmax(const float *input, float *output, int n); 73*77c1e3ccSAndroid Build Coastguard Worker 74*77c1e3ccSAndroid Build Coastguard Worker // A faster but less accurate version of av1_nn_softmax(input, output, 16) 75*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_fast_softmax_16_c(const float *input, float *output); 76*77c1e3ccSAndroid Build Coastguard Worker 77*77c1e3ccSAndroid Build Coastguard Worker // Applies a precision reduction to output of av1_nn_predict to prevent 78*77c1e3ccSAndroid Build Coastguard Worker // mismatches between C and SIMD implementations. 79*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_output_prec_reduce(float *const output, int num_output); 80*77c1e3ccSAndroid Build Coastguard Worker 81*77c1e3ccSAndroid Build Coastguard Worker #ifdef __cplusplus 82*77c1e3ccSAndroid Build Coastguard Worker } // extern "C" 83*77c1e3ccSAndroid Build Coastguard Worker #endif 84*77c1e3ccSAndroid Build Coastguard Worker 85*77c1e3ccSAndroid Build Coastguard Worker #endif // AOM_AV1_ENCODER_ML_H_ 86