xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
17 
18 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
19 
20 namespace mlir {
21 namespace TFL {
22 namespace tac {
23 
NotTFLQuantDequantizeOp(Operation * op)24 bool NotTFLQuantDequantizeOp(Operation* op) {
25   if (!op) return false;
26   if (llvm::isa<TFL::QuantizeOp, TFL::DequantizeOp>(op)) return false;
27   return true;
28 }
29 
IsTerminatorOp(Operation * op)30 bool IsTerminatorOp(Operation* op) {
31   if (!op) return false;
32   return op->hasTrait<OpTrait::IsTerminator>();
33 }
34 
35 // Try to guess the inference type of the op.
GetInferenceType(Operation * op)36 InferenceType GetInferenceType(Operation* op) {
37   bool float_type_observed = false;
38   bool int8_type_observed = false;
39   bool uint8_type_observed = false;
40   for (auto& input : op->getOpOperands()) {
41     auto input_type = input.get().getType();
42     if (IsF32ShapedType(input_type)) {
43       float_type_observed = true;
44     } else if (IsQI8Type(input_type)) {
45       int8_type_observed = true;
46     } else if (IsQUI8Type(input_type)) {
47       uint8_type_observed = true;
48     }
49   }
50 
51   // We should not observe both uint8 & int8.
52   if (int8_type_observed && uint8_type_observed) return UNKNOWN;
53 
54   if (float_type_observed) {
55     if (int8_type_observed || uint8_type_observed) {
56       return HYBRID;
57     } else {
58       return FLOAT;
59     }
60   }
61 
62   if (int8_type_observed) {
63     return QUANTIZED_INT8;
64   }
65 
66   if (uint8_type_observed) {
67     return QUANTIZED_UINT8;
68   }
69 
70   // Default to float inference.
71   return FLOAT;
72 }
73 
74 }  // namespace tac
75 }  // namespace TFL
76 }  // namespace mlir
77