xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
17 
18 #include <numeric>
19 
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/Support/MathExtras.h"
23 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/IR/Matchers.h"  // from @llvm-project
27 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
28 
29 using namespace mlir;
30 using namespace mlir::quantfork;
31 
32 using mlir::quant::QuantizedType;
33 
34 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsDialect.cc.inc"
35 
// Registers all operations of this dialect. The operation list is generated
// by TableGen: GET_OP_LIST expands the .cc.inc include into a comma-separated
// list of op classes that is spliced into the addOperations<> template args.
void QuantizationForkDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc.inc"
      >();
}
42 
fold(ArrayRef<Attribute> operands)43 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
44   // Matches x -> [scast -> scast] -> y, replacing the second scast with the
45   // value of x if the casts invert each other.
46   auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
47   if (!srcScastOp || srcScastOp.getArg().getType() != getType())
48     return OpFoldResult();
49   return srcScastOp.getArg();
50 }
51 
52 /// The quantization specification should match the expressed type.
isValidQuantizationSpec(Attribute quantSpec,Type expressed)53 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
54   if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
55     Type spec = typeAttr.getValue();
56     if (spec.isa<TensorType, VectorType>()) return false;
57 
58     // The spec should be either a quantized type which is compatible to the
59     // expressed type, or a primitive type which is as same as the
60     // (element type of) the expressed type.
61     if (auto quantizedType = spec.dyn_cast<QuantizedType>())
62       return quantizedType.isCompatibleExpressedType(expressed);
63 
64     if (auto tensorType = expressed.dyn_cast<TensorType>())
65       return spec == tensorType.getElementType();
66 
67     if (auto vectorType = expressed.dyn_cast<VectorType>())
68       return spec == vectorType.getElementType();
69   }
70   return false;
71 }
72 
verify()73 LogicalResult QuantizeRegionOp::verify() {
74   // There are specifications for both inputs and outputs.
75   if (getNumOperands() != getInputSpecs().size() ||
76       getNumResults() != getOutputSpecs().size())
77     return emitOpError(
78         "has unmatched operands/results number and spec attributes number");
79 
80   // Verify that quantization specifications are valid.
81   for (auto input : llvm::zip(getOperandTypes(), getInputSpecs())) {
82     Type inputType = std::get<0>(input);
83     Attribute inputSpec = std::get<1>(input);
84     if (!isValidQuantizationSpec(inputSpec, inputType)) {
85       return emitOpError() << "has incompatible specification " << inputSpec
86                            << " and input type " << inputType;
87     }
88   }
89 
90   for (auto result : llvm::zip(getResultTypes(), getOutputSpecs())) {
91     Type outputType = std::get<0>(result);
92     Attribute outputSpec = std::get<1>(result);
93     if (!isValidQuantizationSpec(outputSpec, outputType)) {
94       return emitOpError() << "has incompatible specification " << outputSpec
95                            << " and output type " << outputType;
96     }
97   }
98   return success();
99 }
100 
verify()101 LogicalResult StatisticsOp::verify() {
102   auto tensorArg = getArg().getType().dyn_cast<TensorType>();
103   if (!tensorArg) return emitOpError("arg needs to be tensor type.");
104 
105   // Verify layerStats attribute.
106   {
107     auto layerStatsType = getLayerStats().getType();
108     if (!layerStatsType.getElementType().isa<FloatType>()) {
109       return emitOpError("layerStats must have a floating point element type");
110     }
111     if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
112       return emitOpError("layerStats must have shape [2]");
113     }
114   }
115   // Verify axisStats (optional) attribute.
116   if (getAxisStats()) {
117     if (!getAxis()) return emitOpError("axis must be specified for axisStats");
118 
119     auto shape = tensorArg.getShape();
120     auto argSliceSize =
121         std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1,
122                         std::multiplies<int64_t>());
123 
124     auto axisStatsType = getAxisStats()->getType();
125     if (!axisStatsType.getElementType().isa<FloatType>()) {
126       return emitOpError("axisStats must have a floating point element type");
127     }
128     if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 ||
129         axisStatsType.getDimSize(0) != argSliceSize) {
130       return emitOpError(
131           "axisStats must have shape [N,2] "
132           "where N = the slice size defined by the axis dim");
133     }
134   }
135   return success();
136 }
137 
138 #define GET_OP_CLASSES
139 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc.inc"
140