1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "FullyConnected.h"
18 
19 #include <vector>
20 
21 #include "OperationsValidationUtils.h"
22 
23 namespace android::nn {
24 namespace fully_connected {
25 
validateShapes(const Shape & input,const Shape & weights,const Shape & bias,Shape * output)26 bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias, Shape* output) {
27     // Check all the parameters of tensor match within themselves and match the
28     // input configuration.
29     NN_RET_CHECK(weights.type == input.type);
30     if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
31         input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
32         NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
33     } else {
34         NN_RET_CHECK(bias.type == input.type);
35     }
36     // The Tensorflow fully connected layer specification says that input should
37     // be of at least rank 2, so we check. Tflite doesn't check.
38     NN_RET_CHECK_GE(getNumberOfDimensions(input), 2u);
39     NN_RET_CHECK_LE(getNumberOfDimensions(input), 4u);
40     NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2u);
41     NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1u);
42     uint32_t input_n_elements = getNumberOfElements(input);
43     uint32_t num_units = getSizeOfDimension(weights, 0u);
44     uint32_t input_size = getSizeOfDimension(weights, 1u);
45     uint32_t bias_len = getSizeOfDimension(bias, 0u);
46     uint32_t batch_size = 0;
47     if (input_size != 0) {
48         NN_RET_CHECK_EQ(input_n_elements % input_size, 0u);
49         batch_size = input_n_elements / input_size;
50     }
51     if (num_units != 0 && bias_len != 0) {
52         NN_RET_CHECK_EQ(bias_len, num_units);
53     }
54     if (output != nullptr) {
55         // Only batch_size can be 0.
56         NN_RET_CHECK_GT(num_units, 0u);
57         NN_RET_CHECK_GT(input_size, 0u);
58         output->type = input.type;
59         output->dimensions = {batch_size, num_units};
60     }
61     return true;
62 }
63 
validate(const IOperationValidationContext * context)64 Result<Version> validate(const IOperationValidationContext* context) {
65     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
66     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
67     auto inputType = context->getInputType(kInputTensor);
68     std::vector<OperandType> inExpectedTypes;
69     std::vector<OperandType> outExpectedTypes;
70     auto minSupportedVersion = kVersionFeatureLevel1;
71     if (inputType == OperandType::TENSOR_FLOAT32) {
72         minSupportedVersion = kVersionFeatureLevel1;
73         inExpectedTypes = {
74                 OperandType::TENSOR_FLOAT32,
75                 OperandType::TENSOR_FLOAT32,
76                 OperandType::TENSOR_FLOAT32,
77                 OperandType::INT32,
78         };
79     } else if (inputType == OperandType::TENSOR_FLOAT16) {
80         minSupportedVersion = kVersionFeatureLevel3;
81         inExpectedTypes = {
82                 OperandType::TENSOR_FLOAT16,
83                 OperandType::TENSOR_FLOAT16,
84                 OperandType::TENSOR_FLOAT16,
85                 OperandType::INT32,
86         };
87     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
88         // NeuralNetworks.h specifies that ANEURALNETWORKS_FULLY_CONNECTED's output must
89         // meet "outputScale > inputScale * weightsScale" for the operand type
90         // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29.
91         const float inputScale = context->getInputShape(kInputTensor).scale;
92         const float weightsScale = context->getInputShape(kWeightsTensor).scale;
93         const float outputScale = context->getOutputShape(kOutputTensor).scale;
94         bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);
95 
96         if (!meetsQuantizedScaleConstraintBeforeV1_2) {
97             minSupportedVersion = kVersionFeatureLevel3;
98         } else {
99             minSupportedVersion = kVersionFeatureLevel1;
100         }
101 
102         inExpectedTypes = {
103                 OperandType::TENSOR_QUANT8_ASYMM,
104                 OperandType::TENSOR_QUANT8_ASYMM,
105                 OperandType::TENSOR_INT32,
106                 OperandType::INT32,
107         };
108     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
109         minSupportedVersion = kVersionFeatureLevel4;
110 
111         inExpectedTypes = {
112                 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
113                 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
114                 OperandType::TENSOR_INT32,
115                 OperandType::INT32,
116         };
117     } else {
118         NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
119     }
120     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
121     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
122 
123     Shape input = context->getInputShape(kInputTensor);
124     Shape weights = context->getInputShape(kWeightsTensor);
125     Shape bias = context->getInputShape(kBiasTensor);
126     if (hasKnownRank(input) && hasKnownRank(weights) && hasKnownRank(bias)) {
127         NN_RET_CHECK(validateShapes(input, weights, bias));
128     }
129 
130     return minSupportedVersion;
131 }
132 
133 }  // namespace fully_connected
134 
// Hooks fully_connected::validate up as the validator for the FULLY_CONNECTED
// operation through the project's NN_DEFINE_VALIDATION_FUNCTION macro.
NN_DEFINE_VALIDATION_FUNCTION(FULLY_CONNECTED, fully_connected::validate);
136 
137 }  // namespace android::nn
138