1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
17
18 #include <numeric>
19
20 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
21
22 using namespace mlir;
23 using namespace mlir::quantfork;
24
isQuantizablePrimitiveType(Type inputType)25 static bool isQuantizablePrimitiveType(Type inputType) {
26 return inputType.isa<FloatType>();
27 }
28
forInputType(Type inputType)29 ExpressedToQuantizedConverter ExpressedToQuantizedConverter::forInputType(
30 Type inputType) {
31 if (inputType.isa<TensorType, VectorType>()) {
32 Type elementType = inputType.cast<ShapedType>().getElementType();
33 if (!isQuantizablePrimitiveType(elementType))
34 return ExpressedToQuantizedConverter{inputType, nullptr};
35 return ExpressedToQuantizedConverter{inputType, elementType};
36 }
37 // Supported primitive type (which just is the expressed type).
38 if (isQuantizablePrimitiveType(inputType))
39 return ExpressedToQuantizedConverter{inputType, inputType};
40 // Unsupported.
41 return ExpressedToQuantizedConverter{inputType, nullptr};
42 }
43
convert(quant::QuantizedType elementalType) const44 Type ExpressedToQuantizedConverter::convert(
45 quant::QuantizedType elementalType) const {
46 assert(expressedType && "convert() on unsupported conversion");
47 if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
48 return RankedTensorType::get(tensorType.getShape(), elementalType);
49 if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
50 return UnrankedTensorType::get(elementalType);
51 if (auto vectorType = inputType.dyn_cast<VectorType>())
52 return VectorType::get(vectorType.getShape(), elementalType);
53
54 // If the expressed types match, just use the new elemental type.
55 if (elementalType.getExpressedType() == expressedType) return elementalType;
56 // Unsupported.
57 return nullptr;
58 }
59
convert(Attribute realValue)60 ElementsAttr UniformQuantizedPerAxisValueConverter::convert(
61 Attribute realValue) {
62 if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
63 return convert(attr);
64 }
65 // TODO: handles sparse elements attribute
66 return nullptr;
67 }
68
convert(DenseFPElementsAttr attr)69 DenseElementsAttr UniformQuantizedPerAxisValueConverter::convert(
70 DenseFPElementsAttr attr) {
71 // Creates the converter for each chunk. Normally the size of the
72 // quantization dim is 3, so we can cache all the converters.
73 ShapedType type = attr.getType();
74 size_t dimSize = type.getDimSize(quantizationDim);
75 if (dimSize != scales.size()) {
76 return {};
77 }
78 SmallVector<UniformQuantizedValueConverter, 4> converters;
79 converters.reserve(dimSize);
80 for (int i = 0, e = dimSize; i != e; ++i) {
81 converters.push_back(getPerChunkConverter(i));
82 }
83
84 // Scan the elements of the dense elements attributes and quantize them by
85 // using the right quantization parameters.
86 int64_t flattenIndex = 0;
87 auto shape = type.getShape();
88 int64_t chunkSize =
89 std::accumulate(std::next(shape.begin(), quantizationDim + 1),
90 shape.end(), 1, std::multiplies<int64_t>());
91 Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
92 return attr.mapValues(newElementType, [&](const APFloat &old) {
93 int chunkIndex = (flattenIndex++) / chunkSize;
94 return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
95 });
96 }
97