xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/utils/constant_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
17 
18 #include <string>
19 
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
23 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/platform/status.h"
27 
28 namespace mlir {
29 namespace TFL {
30 
CreateConstOpWithSingleValue(PatternRewriter * rewriter,Location loc,ShapedType shaped_type,int value)31 stream_executor::port::StatusOr<arith::ConstantOp> CreateConstOpWithSingleValue(
32     PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
33     int value) {
34   Type element_type = shaped_type.getElementType();
35   ShapedType scalar_type = RankedTensorType::get({}, element_type);
36   Attribute attr;
37   if (element_type.isF16()) {
38     auto floatType = mlir::FloatType::getF16(element_type.getContext());
39     auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
40     std::vector<Attribute> floatValues({floatAttr});
41     attr = DenseElementsAttr::get(scalar_type, floatValues);
42   } else if (element_type.isBF16()) {
43     auto floatType = mlir::FloatType::getBF16(element_type.getContext());
44     auto floatAttr = mlir::FloatAttr::get(floatType, static_cast<float>(value));
45     std::vector<Attribute> floatValues({floatAttr});
46     attr = DenseElementsAttr::get(scalar_type, floatValues);
47   } else if (element_type.isF32()) {
48     attr =
49         DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
50   } else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
51     auto etype = complex_type.getElementType();
52     if (etype.isF32()) {
53       tensorflow::TensorProto repr;
54       repr.set_dtype(tensorflow::DT_COMPLEX64);
55 
56       tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
57       shape->set_unknown_rank(false);
58       shape->add_dim()->set_size(int64_t{1});
59       std::string content;
60       auto complex_value = std::complex<float>(static_cast<float>(value), 0.0f);
61       content.assign(reinterpret_cast<const char*>(&complex_value),
62                      sizeof(complex_value));
63       repr.set_tensor_content(content);
64       std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
65 
66       attr = mlir::TF::TensorProtoAttr::get(scalar_type, mangled);
67     } else {
68       return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
69                                 "Unsupported type");
70     }
71   } else if (auto itype = element_type.dyn_cast<mlir::IntegerType>()) {
72     if (element_type.isSignedInteger()) {
73       switch (itype.getWidth()) {
74         case 8:
75           attr = DenseElementsAttr::get<int8_t>(scalar_type,
76                                                 static_cast<int8_t>(value));
77           break;
78         case 16:
79           attr = DenseElementsAttr::get<int16_t>(scalar_type,
80                                                  static_cast<int16_t>(value));
81           break;
82         case 32:
83           attr = DenseElementsAttr::get<int32_t>(scalar_type,
84                                                  static_cast<int32_t>(value));
85           break;
86         case 64:
87           attr = DenseElementsAttr::get<int64_t>(scalar_type,
88                                                  static_cast<int64_t>(value));
89           break;
90         default:
91           return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
92                                     "Unsupported type");
93       }
94     } else {
95       switch (itype.getWidth()) {
96         case 8:
97           attr = DenseElementsAttr::get<uint8_t>(scalar_type,
98                                                  static_cast<uint8_t>(value));
99           break;
100         case 16:
101           attr = DenseElementsAttr::get<uint16_t>(scalar_type,
102                                                   static_cast<uint16_t>(value));
103           break;
104         case 32:
105           attr = DenseElementsAttr::get<uint32_t>(scalar_type,
106                                                   static_cast<uint32_t>(value));
107           break;
108         case 64:
109           attr = DenseElementsAttr::get<uint64_t>(scalar_type,
110                                                   static_cast<uint64_t>(value));
111           break;
112         default:
113           return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
114                                     "Unsupported type");
115       }
116     }
117   } else {
118     return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
119                               "Unsupported type");
120   }
121   return rewriter->create<arith::ConstantOp>(loc, scalar_type, attr);
122 }
123 
124 }  // namespace TFL
125 }  // namespace mlir
126