/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H
#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H

#include <algorithm>

#include "llvm/ADT/Sequence.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

// The include order below matters: the generated attribute classes use the
// generated enum types.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h.inc"

namespace mlir {
namespace mhlo {

// Forward declaration for a function declared in hlo_ops.h.
bool isCompatibleForMhloTypeInference(Type tp1, Type tp2);

namespace OpTrait {
template <typename ConcreteType>
class BroadcastingElementwise
    : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};

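// Verifies that the op has the same number of operands and results, and that
// the operand and result at each index have exactly the same type.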
template <typename ConcreteType>
class PairwiseSameOperandAndResultType
    : public mlir::OpTrait::TraitBase<ConcreteType,
                                      PairwiseSameOperandAndResultType> {
 public:
  static LogicalResult verifyTrait(Operation *op) {
    const int numOperands = op->getNumOperands();
    const int numResults = op->getNumResults();
    if (numOperands != numResults) {
      return op->emitOpError()
             << "requires the same number of operands and results";
    }

    for (int idx : llvm::seq<int>(0, numOperands)) {
      if (op->getOperand(idx).getType() != op->getResult(idx).getType()) {
        return op->emitOpError()
               << "requires the same type for operand and result at index "
               << idx;
      }
    }
    return success();
  }
};
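
// A minimal usage sketch (the op below is hypothetical, for illustration
// only): these traits mix into an op declaration like any other OpTrait, e.g.
//
//   class MyPairwiseOp
//       : public mlir::Op<MyPairwiseOp,
//                         mhlo::OpTrait::PairwiseSameOperandAndResultType> {
//     ...
//   };
//
// In ODS, ops typically attach them through a NativeOpTrait whose
// cppNamespace points at ::mlir::mhlo::OpTrait.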
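// Verifies that every operand and result type of the op is pairwise
// compatible in the sense of isCompatibleForMhloTypeInference, and provides
// InferTypeOpInterface-style helpers that infer the most specific result
// type from the operand types.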
template <typename ConcreteType>
class CompatibleOperandsAndResultType
    : public mlir::OpTrait::TraitBase<ConcreteType,
                                      CompatibleOperandsAndResultType> {
 public:
  static LogicalResult verifyTrait(Operation *op) {
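    // Pick a reference type to compare everything against: prefer the first
    // operand's type, falling back to the first result's type for
    // zero-operand ops.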
    Type expected;
    if (op->getNumResults() != 0) expected = op->getResult(0).getType();
    if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
    if (!expected) return failure();

    auto typeMatch = [&](Type actual) {
      return isCompatibleForMhloTypeInference(actual, expected);
    };
    auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) &&
                    llvm::all_of(op->getResultTypes(), typeMatch);
    if (!allMatch) {
      return op->emitOpError(
          "requires compatible types for all operands and results");
    }

    return success();
  }

  static LogicalResult inferReturnTypes(
      MLIRContext *context, Optional<Location> location, ValueRange operands,
      DictionaryAttr /*attributes*/, RegionRange /*regions*/,
      SmallVectorImpl<Type> &inferredReturnTypes) {
    // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
    // support quantization or sparsity.
    if (operands.empty())
      return emitOptionalError(
          location,
          "Expected non-empty operands for [CompatibleOperandsAndResultType]");

    return inferMostSpecificType(context, location, operands.getTypes(),
                                 inferredReturnTypes);
  }

  // This function is not called automatically by the framework; it needs to
  // be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS (see examples
  // in hlo_ops.cc).
  static LogicalResult inferReturnTypeComponentsFromOperands(
      MLIRContext *context, Optional<Location> location,
      ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
    SmallVector<Type> inferredReturnTypes;
    if (failed(inferReturnTypes(context, location, operands.getValues(),
                                attributes, regions, inferredReturnTypes)))
      return failure();
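    // inferReturnTypes above produces exactly one type, shared by all
    // results of the op.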
    auto inferredReturnType = inferredReturnTypes[0].cast<ShapedType>();
    inferredReturnShapes.push_back(inferredReturnType);
    return success();
  }

 private:
  // Cases of inferring the return shape with bounds (lhs and rhs are
  // commutative):
  //       Dim of lhs     Dim of rhs      Infer
  //  c0:  3              3               3
  //  c1:  3              ?               3
  //  c2:  3              ?, bound=4      3
  //  c3:  3              ?, bound=2      Error out
  //  c4:  ?              ?               ?
  //  c5:  ?              ?, bound=3      ?, bound=3
  //  c6:  ?, bound=3     ?, bound=3      ?, bound=3
  //  c7:  ?, bound=3     ?, bound=4      ?, bound=3
  // This method generalizes the table above to any number of inputs: 1) the
  // static dimension size of any input (if present) becomes the inferred
  // dimension size, and 2) the minimum of the input bounds becomes the
  // inferred bound.
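  // For example (illustrative; the encoding syntax below is how
  // TypeExtensionsAttr typically prints): merging tensor<3x?xf32> with
  // tensor<?x?xf32, #mhlo.type_extensions<bounds = [5, 4]>> yields
  // tensor<3x?xf32, #mhlo.type_extensions<bounds = [?, 4]>>: dimension 0
  // takes the static size 3 and drops its bound (c2), while dimension 1
  // stays dynamic and keeps the bound 4 (c5).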
  static LogicalResult inferMostSpecificType(
      MLIRContext *context, Optional<Location> location,
      ValueTypeRange<ValueRange> inputTypes,
      SmallVectorImpl<Type> &inferredReturnTypes) {
    SmallVector<RankedTensorType> rankedTypes;
    for (auto inputType : inputTypes)
      if (auto rankedType = inputType.dyn_cast<RankedTensorType>())
        rankedTypes.push_back(rankedType);
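    // If no input is ranked there is nothing to merge, so fall back to the
    // first input type.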
    if (rankedTypes.empty()) {
      inferredReturnTypes.push_back(inputTypes[0]);
      return success();
    }

    auto rank = rankedTypes[0].getRank();
    SmallVector<int64_t> inferredDimSizes(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t> inferredBounds(rank, ShapedType::kDynamicSize);
    for (auto rankedType : rankedTypes) {
      // All ranked inputs must agree on the rank; otherwise the
      // per-dimension merge below would read out of bounds.
      if (rankedType.getRank() != rank)
        return emitOptionalError(location, "Mismatched ranks ", rank, " and ",
                                 rankedType.getRank());
      SmallVector<int64_t> bounds;
      if (auto encoding =
              rankedType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>()) {
        bounds = llvm::to_vector<4>(encoding.getBounds());
      } else if (rankedType.getEncoding()) {
        // TODO(zhouxin) infer sparsity encoding after b/238903065 is fixed.
        inferredReturnTypes.push_back(inputTypes[0]);
        return success();
      }

      for (int dim = 0; dim < rank; ++dim) {
        // Dimension sizes.
        auto dimSize = rankedType.getShape()[dim];
        if (inferredDimSizes[dim] != ShapedType::kDynamicSize &&
            dimSize != ShapedType::kDynamicSize &&
            inferredDimSizes[dim] != dimSize)
          return emitOptionalError(location, "Mismatched dimension sizes ",
                                   inferredDimSizes[dim], " and ", dimSize,
                                   " in dimension ", dim);
        if (inferredDimSizes[dim] == ShapedType::kDynamicSize)
          inferredDimSizes[dim] = dimSize;

        // Bounds: keep the minimum of all bounds seen so far.
        if (!bounds.empty() && bounds[dim] != ShapedType::kDynamicSize) {
          if (inferredBounds[dim] == ShapedType::kDynamicSize) {
            inferredBounds[dim] = bounds[dim];
          } else {
            inferredBounds[dim] = std::min(inferredBounds[dim], bounds[dim]);
          }
        }
        // Error out if the inferred bound is smaller than the inferred
        // static dimension size.
        if (inferredBounds[dim] != ShapedType::kDynamicSize &&
            inferredBounds[dim] < inferredDimSizes[dim])
          return emitOptionalError(location,
                                   "bound must not be less than static "
                                   "dimension size but has bound ",
                                   inferredBounds[dim], " vs static size ",
                                   inferredDimSizes[dim], " in dimension ",
                                   dim);
        // A bound is meaningless once the dimension is static; drop it.
        if (inferredDimSizes[dim] != ShapedType::kDynamicSize)
          inferredBounds[dim] = ShapedType::kDynamicSize;
      }
    }
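    // Attach a TypeExtensionsAttr only when at least one dimension still
    // carries a bound; otherwise leave the encoding empty.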
    Attribute encoding = nullptr;
    if (llvm::any_of(inferredBounds,
                     [](auto el) { return el != ShapedType::kDynamicSize; }))
      encoding = TypeExtensionsAttr::get(context, inferredBounds);
    inferredReturnTypes.push_back(RankedTensorType::get(
        inferredDimSizes, rankedTypes[0].getElementType(), encoding));

    return success();
  }
};

}  // namespace OpTrait
}  // namespace mhlo
}  // namespace mlir

#endif  // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H
213