/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "FullyConnected.h"

#include <vector>

#include "OperationsValidationUtils.h"

namespace android::nn {
namespace fully_connected {

bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias, Shape* output) {
    // Check that the tensor parameters are consistent with one another and
    // with the input configuration.
    NN_RET_CHECK(weights.type == input.type);
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(bias.type == input.type);
    }
    // The TensorFlow fully connected layer specification says that the input
    // should have at least rank 2, so we check that here. TFLite does not.
    NN_RET_CHECK_GE(getNumberOfDimensions(input), 2u);
    NN_RET_CHECK_LE(getNumberOfDimensions(input), 4u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1u);
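    // Per the FULLY_CONNECTED specification, weights is a 2-D tensor of shape
    // [num_units, input_size] and bias is a 1-D tensor of shape [num_units];
    // the input is treated as if flattened to [batch_size, input_size].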
    uint32_t input_n_elements = getNumberOfElements(input);
    uint32_t num_units = getSizeOfDimension(weights, 0u);
    uint32_t input_size = getSizeOfDimension(weights, 1u);
    uint32_t bias_len = getSizeOfDimension(bias, 0u);
    uint32_t batch_size = 0;
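    // A dimension of size 0 is still unspecified at this point, so the derived
    // consistency checks below only run when the relevant sizes are known.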
    if (input_size != 0) {
        NN_RET_CHECK_EQ(input_n_elements % input_size, 0u);
        batch_size = input_n_elements / input_size;
    }
    if (num_units != 0 && bias_len != 0) {
        NN_RET_CHECK_EQ(bias_len, num_units);
    }
    if (output != nullptr) {
        // Only batch_size can be 0.
        NN_RET_CHECK_GT(num_units, 0u);
        NN_RET_CHECK_GT(input_size, 0u);
        output->type = input.type;
        output->dimensions = {batch_size, num_units};
    }
    return true;
}

Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    auto minSupportedVersion = kVersionFeatureLevel1;
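    // The input tensor type determines both the expected types of the remaining
    // operands and the minimum NNAPI feature level required for this configuration.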
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = kVersionFeatureLevel1;
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
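        // TENSOR_FLOAT16 operands were introduced in NNAPI feature level 3 (HAL 1.2).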
        minSupportedVersion = kVersionFeatureLevel3;
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // NeuralNetworks.h specifies that ANEURALNETWORKS_FULLY_CONNECTED's output must
        // meet "outputScale > inputScale * weightsScale" for the operand type
        // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29.
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float weightsScale = context->getInputShape(kWeightsTensor).scale;
        const float outputScale = context->getOutputShape(kOutputTensor).scale;
        bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);

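        // From feature level 3 (API level 29) onward this scale restriction is
        // lifted, so a model that violates it requires at least that version.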
        if (!meetsQuantizedScaleConstraintBeforeV1_2) {
            minSupportedVersion = kVersionFeatureLevel3;
        } else {
            minSupportedVersion = kVersionFeatureLevel1;
        }

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
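        // TENSOR_QUANT8_ASYMM_SIGNED was introduced in NNAPI feature level 4 (HAL 1.3).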
        minSupportedVersion = kVersionFeatureLevel4;

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
                OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));

    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
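    // Shape consistency can only be validated once the ranks of all three
    // tensors are known; individual dimensions may still be unspecified.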
    if (hasKnownRank(input) && hasKnownRank(weights) && hasKnownRank(bias)) {
        NN_RET_CHECK(validateShapes(input, weights, bias));
    }

    return minSupportedVersion;
}

}  // namespace fully_connected

NN_DEFINE_VALIDATION_FUNCTION(FULLY_CONNECTED, fully_connected::validate);

}  // namespace android::nn