1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"

#include <complex>
#include <cstdint>
#include <string>

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/status.h"
27
28 namespace mlir {
29 namespace TFL {
30
CreateConstOpWithSingleValue(PatternRewriter * rewriter,Location loc,ShapedType shaped_type,int value)31 stream_executor::port::StatusOr<arith::ConstantOp> CreateConstOpWithSingleValue(
32 PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
33 int value) {
34 Type element_type = shaped_type.getElementType();
35 ShapedType scalar_type = RankedTensorType::get({}, element_type);
36 Attribute attr;
37 if (element_type.isF16()) {
38 auto floatType = mlir::FloatType::getF16(element_type.getContext());
39 auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
40 std::vector<Attribute> floatValues({floatAttr});
41 attr = DenseElementsAttr::get(scalar_type, floatValues);
42 } else if (element_type.isBF16()) {
43 auto floatType = mlir::FloatType::getBF16(element_type.getContext());
44 auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
45 std::vector<Attribute> floatValues({floatAttr});
46 attr = DenseElementsAttr::get(scalar_type, floatValues);
47 } else if (element_type.isF32()) {
48 attr =
49 DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
50 } else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
51 auto etype = complex_type.getElementType();
52 if (etype.isF32()) {
53 tensorflow::TensorProto repr;
54 repr.set_dtype(tensorflow::DT_COMPLEX64);
55
56 tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
57 shape->set_unknown_rank(false);
58 shape->add_dim()->set_size(int64_t{1});
59 std::string content;
60 auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
61 content.assign(reinterpret_cast<const char*>(&complex_value),
62 sizeof(complex_value));
63 repr.set_tensor_content(content);
64 std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
65
66 attr = mlir::TF::TensorProtoAttr::get(scalar_type, mangled);
67 } else {
68 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
69 "Unsupported type");
70 }
71 } else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
72 if (element_type.isSignedInteger()) {
73 switch (itype.getWidth()) {
74 case 8:
75 attr = DenseElementsAttr::get<int8_t>(scalar_type,
76 static_cast<int8_t>(value));
77 break;
78 case 16:
79 attr = DenseElementsAttr::get<int16_t>(scalar_type,
80 static_cast<int16_t>(value));
81 break;
82 case 32:
83 attr = DenseElementsAttr::get<int32_t>(scalar_type,
84 static_cast<int32_t>(value));
85 break;
86 case 64:
87 attr = DenseElementsAttr::get<int64_t>(scalar_type,
88 static_cast<int64_t>(value));
89 break;
90 default:
91 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
92 "Unsupported type");
93 }
94 } else {
95 switch (itype.getWidth()) {
96 case 8:
97 attr = DenseElementsAttr::get<uint8_t>(scalar_type,
98 static_cast<uint8_t>(value));
99 break;
100 case 16:
101 attr = DenseElementsAttr::get<uint16_t>(scalar_type,
102 static_cast<uint16_t>(value));
103 break;
104 case 32:
105 attr = DenseElementsAttr::get<uint32_t>(scalar_type,
106 static_cast<uint32_t>(value));
107 break;
108 case 64:
109 attr = DenseElementsAttr::get<uint64_t>(scalar_type,
110 static_cast<uint64_t>(value));
111 break;
112 default:
113 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
114 "Unsupported type");
115 }
116 }
117 } else {
118 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
119 "Unsupported type");
120 }
121 return rewriter->create<arith::ConstantOp>(loc, scalar_type, attr);
122 }
123
124 } // namespace TFL
125 } // namespace mlir
126