/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"

#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project

namespace mlir {
namespace TF {

class IdentityOp;
class IdentityNOp;

29 // Returns the RankedTensorType for the given operand. TensorFlow constant ops
30 // may have non-static shape because the shape is not propagated during constant
31 // folding. If the defining op for the given operand is a constant op, this
32 // routine uses the constant op's attribute to get the actual shape.
GetRankedTensorTypeForOperand(Value operand)33 RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
34 DenseElementsAttr attr;
35 if (matchPattern(operand, m_Constant(&attr))) {
36 return attr.getType().dyn_cast<RankedTensorType>();
37 }
38 return operand.getType().dyn_cast<RankedTensorType>();
39 }
40
41 // Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If
42 // `incompatible_shape_error` is true, reports error if `x` and `y` has
43 // incompatible shapes. Otherwise, returns a tensor type with unknown rank.
DeduceEqualCmpOpType(Builder * builder,Location loc,Value x,Value y,BoolAttr incompatible_shape_error)44 Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, Value y,
45 BoolAttr incompatible_shape_error) {
46 auto result_type =
47 OpTrait::util::getBroadcastedType(x.getType(), y.getType());
48 if (!result_type) {
49 if (incompatible_shape_error.getValue()) {
50 mlir::emitError(loc, "non-broadcastable operands");
51 } else {
52 return UnrankedTensorType::get(builder->getI1Type());
53 }
54 }
55
56 auto ranked_type = result_type.dyn_cast<RankedTensorType>();
57 if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type());
58
59 return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type());
60 }
61
InferReductionOpType(Value input,Value reduction_indices,BoolAttr keep_dims)62 Type InferReductionOpType(Value input, Value reduction_indices,
63 BoolAttr keep_dims) {
64 Type input_ty = input.getType();
65 Type element_ty = getElementTypeOrSelf(input_ty);
66
67 // Output type is unranked if input type is not ranked.
68 auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
69 if (!ranked_ty) return UnrankedTensorType::get(element_ty);
70 int64_t rank = ranked_ty.getRank();
71
72 DenseIntElementsAttr indices;
73 if (!matchPattern(reduction_indices, m_Constant(&indices))) {
74 // Output type is unranked if reduction indices are not constant and reduced
75 // dimensions are not kept.
76 if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty);
77
78 // Otherwise, output type has same rank as the input.
79 return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty);
80 }
81
82 int64_t num_reduce_dim = 0;
83 llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
84 for (const APInt &index : indices.getValues<APInt>()) {
85 int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
86 // Invalid input.
87 if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);
88
89 if (!is_reduce_dim[dim]) {
90 is_reduce_dim[dim] = true;
91 num_reduce_dim++;
92 }
93 }
94
95 ArrayRef<int64_t> shape = ranked_ty.getShape();
96 SmallVector<int64_t, 4> out_shape;
97 out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim));
98 for (int64_t i = 0; i < rank; ++i) {
99 if (!is_reduce_dim[i])
100 out_shape.push_back(shape[i]);
101 else if (keep_dims.getValue())
102 out_shape.push_back(1);
103 }
104 return RankedTensorType::get(out_shape, element_ty);
105 }
106
107 // Verifies that the given types are cast compatible. If not, emits appropriate
108 // error for the given op. If mask_one_dim is set to true, then the types are
109 // allowed to have one mismatching dimension. Masking one of the dimensions is
110 // useful for ops like Concat that requires all ranked inputs to have the same
111 // rank and match dimension sizes for all but one of the dimensions.
VerifyTypesCompatibility(Operation::operand_type_range types,bool mask_one_dim,Operation * op)112 LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types,
113 bool mask_one_dim, Operation *op) {
114 constexpr int64_t kUninitialized = -1;
115 int64_t common_rank = kUninitialized;
116 llvm::SmallVector<int64_t, 4> common_dims;
117 int64_t dim_to_mask = kUninitialized;
118
119 // Initialize common_rank with rank of the first ranked type and verify that
120 // following ranked types have the same rank.
121 // Similarly, initialize each of the dimensions with the first type that has
122 // the dimension size available and verify that all following types have the
123 // same size for the dimension. However, if mask_one_dim is true, note down
124 // the dimension index on the first mismatch and ignore dimension at that
125 // index in following types.
126 for (Type ty : types) {
127 RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
128 if (!ranked_ty) continue;
129
130 int64_t rank = ranked_ty.getRank();
131 if (common_rank == kUninitialized) {
132 common_rank = rank;
133 common_dims.resize(common_rank, kUninitialized);
134 } else if (common_rank != rank) {
135 return op->emitError()
136 << "operand type " << ranked_ty
137 << " is not compatible with preceding operands; expected rank: "
138 << common_rank;
139 }
140
141 for (int64_t i = 0, e = common_rank; i != e; i++) {
142 if (i == dim_to_mask) continue;
143
144 int64_t dim = ranked_ty.getDimSize(i);
145 if (dim == kUninitialized) continue;
146
147 int64_t &common_dim = common_dims[i];
148 if (common_dim == kUninitialized) {
149 common_dim = dim;
150 } else if (common_dim != dim) {
151 // If mask_one_dim is true, do not emit an error if this is the only
152 // dimension with mismatches. Note down the dimension to mask it from
153 // the following types.
154 if (mask_one_dim && dim_to_mask == kUninitialized) {
155 dim_to_mask = i;
156 continue;
157 }
158
159 return op->emitError() << "operand type " << ranked_ty
160 << " is not compatible with preceding operands; "
161 "expected dimension at index "
162 << i << ": " << common_dim;
163 }
164 }
165 }
166 return success();
167 }

}  // namespace TF
}  // namespace mlir