1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
17
18 #include <numeric>
19
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/Twine.h"
22 #include "llvm/Support/MathExtras.h"
23 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/IR/Matchers.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28
29 using namespace mlir;
30 using namespace mlir::quantfork;
31
32 using mlir::quant::QuantizedType;
33
34 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsDialect.cc.inc"
35
// Registers all operations of the quantization fork dialect.
void QuantizationForkDialect::initialize() {
  addOperations<
// GET_OP_LIST makes the generated .cc.inc expand to the comma-separated
// list of op classes that addOperations<> expects.
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc.inc"
      >();
}
42
fold(ArrayRef<Attribute> operands)43 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
44 // Matches x -> [scast -> scast] -> y, replacing the second scast with the
45 // value of x if the casts invert each other.
46 auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
47 if (!srcScastOp || srcScastOp.getArg().getType() != getType())
48 return OpFoldResult();
49 return srcScastOp.getArg();
50 }
51
52 /// The quantization specification should match the expressed type.
isValidQuantizationSpec(Attribute quantSpec,Type expressed)53 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
54 if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
55 Type spec = typeAttr.getValue();
56 if (spec.isa<TensorType, VectorType>()) return false;
57
58 // The spec should be either a quantized type which is compatible to the
59 // expressed type, or a primitive type which is as same as the
60 // (element type of) the expressed type.
61 if (auto quantizedType = spec.dyn_cast<QuantizedType>())
62 return quantizedType.isCompatibleExpressedType(expressed);
63
64 if (auto tensorType = expressed.dyn_cast<TensorType>())
65 return spec == tensorType.getElementType();
66
67 if (auto vectorType = expressed.dyn_cast<VectorType>())
68 return spec == vectorType.getElementType();
69 }
70 return false;
71 }
72
verify()73 LogicalResult QuantizeRegionOp::verify() {
74 // There are specifications for both inputs and outputs.
75 if (getNumOperands() != getInputSpecs().size() ||
76 getNumResults() != getOutputSpecs().size())
77 return emitOpError(
78 "has unmatched operands/results number and spec attributes number");
79
80 // Verify that quantization specifications are valid.
81 for (auto input : llvm::zip(getOperandTypes(), getInputSpecs())) {
82 Type inputType = std::get<0>(input);
83 Attribute inputSpec = std::get<1>(input);
84 if (!isValidQuantizationSpec(inputSpec, inputType)) {
85 return emitOpError() << "has incompatible specification " << inputSpec
86 << " and input type " << inputType;
87 }
88 }
89
90 for (auto result : llvm::zip(getResultTypes(), getOutputSpecs())) {
91 Type outputType = std::get<0>(result);
92 Attribute outputSpec = std::get<1>(result);
93 if (!isValidQuantizationSpec(outputSpec, outputType)) {
94 return emitOpError() << "has incompatible specification " << outputSpec
95 << " and output type " << outputType;
96 }
97 }
98 return success();
99 }
100
verify()101 LogicalResult StatisticsOp::verify() {
102 auto tensorArg = getArg().getType().dyn_cast<TensorType>();
103 if (!tensorArg) return emitOpError("arg needs to be tensor type.");
104
105 // Verify layerStats attribute.
106 {
107 auto layerStatsType = getLayerStats().getType();
108 if (!layerStatsType.getElementType().isa<FloatType>()) {
109 return emitOpError("layerStats must have a floating point element type");
110 }
111 if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
112 return emitOpError("layerStats must have shape [2]");
113 }
114 }
115 // Verify axisStats (optional) attribute.
116 if (getAxisStats()) {
117 if (!getAxis()) return emitOpError("axis must be specified for axisStats");
118
119 auto shape = tensorArg.getShape();
120 auto argSliceSize =
121 std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1,
122 std::multiplies<int64_t>());
123
124 auto axisStatsType = getAxisStats()->getType();
125 if (!axisStatsType.getElementType().isa<FloatType>()) {
126 return emitOpError("axisStats must have a floating point element type");
127 }
128 if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 ||
129 axisStatsType.getDimSize(0) != argSliceSize) {
130 return emitOpError(
131 "axisStats must have shape [N,2] "
132 "where N = the slice size defined by the axis dim");
133 }
134 }
135 return success();
136 }
137
138 #define GET_OP_CLASSES
139 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc.inc"
140