1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H 17 #define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H 18 19 #include <algorithm> 20 21 #include "llvm/ADT/Sequence.h" 22 #include "mlir/IR/OpDefinition.h" 23 #include "mlir/Interfaces/InferTypeOpInterface.h" 24 25 // Include order below matters. 26 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc" 27 #define GET_ATTRDEF_CLASSES 28 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h.inc" 29 30 namespace mlir { 31 namespace mhlo { 32 33 // Forward declaration for a function declared in hlo_ops.h. 34 bool isCompatibleForMhloTypeInference(Type tp1, Type tp2); 35 36 namespace OpTrait { 37 38 template <typename ConcreteType> 39 class BroadcastingElementwise 40 : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {}; 41 42 template <typename ConcreteType> 43 class PairwiseSameOperandAndResultType 44 : public mlir::OpTrait::TraitBase<ConcreteType, 45 PairwiseSameOperandAndResultType> { 46 public: verifyTrait(Operation * op)47 static LogicalResult verifyTrait(Operation *op) { 48 const int numOperands = op->getNumOperands(); 49 const int numResults = op->getNumResults(); 50 if (numOperands != numResults) { 51 return op->emitOpError() 52 << "requires the same number of operands and results"; 53 } 54 55 for (int idx : llvm::seq<int>(0, numOperands)) { 56 if (op->getOperand(idx).getType() != op->getResult(idx).getType()) { 57 return op->emitOpError() 58 << "requires the same type for operand and result at index " 59 << idx; 60 } 61 } 62 return success(); 63 } 64 }; 65 66 template <typename ConcreteType> 67 class CompatibleOperandsAndResultType 68 : public mlir::OpTrait::TraitBase<ConcreteType, 69 CompatibleOperandsAndResultType> { 70 public: verifyTrait(Operation * op)71 static LogicalResult verifyTrait(Operation *op) { 72 Type expected; 73 if (op->getNumResults() != 0) expected = op->getResult(0).getType(); 74 if (op->getNumOperands() != 0) expected = op->getOperand(0).getType(); 75 if (!expected) return failure(); 76 77 auto typeMatch = [&](Type actual) { 78 return isCompatibleForMhloTypeInference(actual, expected); 79 }; 80 auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) && 81 llvm::all_of(op->getResultTypes(), typeMatch); 82 if (!allMatch) { 83 return op->emitOpError( 84 "requires compatible types for all operands and results"); 85 } 86 87 return success(allMatch); 88 } 89 inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)90 static LogicalResult inferReturnTypes( 91 MLIRContext *context, Optional<Location> location, ValueRange operands, 92 DictionaryAttr /*attributes*/, RegionRange /*regions*/, 93 SmallVectorImpl<Type> &inferredReturnTypes) { 94 // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that 95 // support quantization or sparsity. 96 if (operands.empty()) 97 return emitOptionalError( 98 location, 99 "Expected non-empty operands for [CompatibleOperandsAndResultType]"); 100 101 if (failed(inferMostSpecificType(context, location, operands.getTypes(), 102 inferredReturnTypes))) 103 return failure(); 104 return success(); 105 } 106 107 // This function is not going to be called automatically. 108 // It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS 109 // (see examples in hlo_ops.cc). inferReturnTypeComponentsFromOperands(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)110 static LogicalResult inferReturnTypeComponentsFromOperands( 111 MLIRContext *context, Optional<Location> location, 112 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, 113 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 114 SmallVector<Type> inferredReturnTypes; 115 if (failed(inferReturnTypes(context, location, operands.getValues(), 116 attributes, regions, inferredReturnTypes))) 117 return failure(); 118 auto inferredReturnType = inferredReturnTypes[0].cast<ShapedType>(); 119 inferredReturnShapes.push_back(inferredReturnType); 120 return success(); 121 } 122 123 private: 124 // Cases of infer return shape with bounds (lhs and rhs are commutative): 125 // Dim of lhs Dim of rhs Infer 126 // c0: 3 3 3 127 // c1: 3 ? 3 128 // c2: 3 ?, bound=4 3 129 // c3: 3 ?, bound=2 Error out 130 // c4: ? ? ? 131 // c5: ? ?, bound=3 ?, bound=3 132 // c6: ?, bound=3 ?, bound=3 ?, bound=3 133 // c7: ?, bound=3 ?, bound=4 ?, bound=3 134 // This method generalizes it to multiple inputs: 1) get the static input dims 135 // (if any) as infer dim, and 2) get min of input bounds as infer bound inferMostSpecificType(MLIRContext * context,Optional<Location> location,ValueTypeRange<ValueRange> inputTypes,SmallVectorImpl<Type> & inferredReturnTypes)136 static LogicalResult inferMostSpecificType( 137 MLIRContext *context, Optional<Location> location, 138 ValueTypeRange<ValueRange> inputTypes, 139 SmallVectorImpl<Type> &inferredReturnTypes) { 140 SmallVector<RankedTensorType> rankedTypes; 141 for (auto inputType : inputTypes) 142 if (auto rankedType = inputType.dyn_cast<RankedTensorType>()) 143 rankedTypes.push_back(rankedType); 144 if (rankedTypes.empty()) { 145 inferredReturnTypes.push_back(inputTypes[0]); 146 return success(); 147 } 148 149 auto rank = rankedTypes[0].getRank(); 150 SmallVector<int64_t> inferredDimSizes(rank, ShapedType::kDynamicSize); 151 SmallVector<int64_t> inferredBounds(rank, ShapedType::kDynamicSize); 152 for (auto rankedType : rankedTypes) { 153 SmallVector<int64_t> bounds; 154 if (auto encoding = 155 rankedType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>()) { 156 bounds = llvm::to_vector<4>(encoding.getBounds()); 157 } else if (rankedType.getEncoding()) { 158 // TODO(zhouxin) infer sparsity encoding after b/238903065 is fixed. 159 inferredReturnTypes.push_back(inputTypes[0]); 160 return success(); 161 } 162 163 for (int dim = 0; dim < rank; ++dim) { 164 // Dimensions 165 auto dimSize = rankedType.getShape()[dim]; 166 if (inferredDimSizes[dim] != ShapedType::kDynamicSize && 167 dimSize != ShapedType::kDynamicSize && 168 inferredDimSizes[dim] != dimSize) 169 return emitOptionalError(location, "Mismatch dimension size ", 170 inferredDimSizes[dim], " and ", dimSize, 171 " in dimension ", dim); 172 if (inferredDimSizes[dim] == ShapedType::kDynamicSize) 173 inferredDimSizes[dim] = dimSize; 174 175 // Bounds 176 if (!bounds.empty() && bounds[dim] != ShapedType::kDynamicSize) { 177 if (inferredBounds[dim] == ShapedType::kDynamicSize) { 178 inferredBounds[dim] = bounds[dim]; 179 } else { 180 inferredBounds[dim] = std::min(inferredBounds[dim], bounds[dim]); 181 } 182 } 183 // Error out case that the inferred bound is smaller than inferred dim 184 if (inferredBounds[dim] != ShapedType::kDynamicSize && 185 inferredBounds[dim] < inferredDimSizes[dim]) 186 return emitOptionalError(location, 187 "bound must not be less than static " 188 "dimension size but has bound ", 189 inferredBounds[dim], " vs static size ", 190 inferredDimSizes[dim], " in dimension ", 191 dim); 192 if (inferredDimSizes[dim] != ShapedType::kDynamicSize) 193 inferredBounds[dim] = ShapedType::kDynamicSize; 194 } 195 } 196 197 Attribute encoding = nullptr; 198 if (llvm::any_of(inferredBounds, 199 [](auto el) { return el != ShapedType::kDynamicSize; })) 200 encoding = TypeExtensionsAttr::get(context, inferredBounds); 201 inferredReturnTypes.push_back(RankedTensorType::get( 202 inferredDimSizes, rankedTypes[0].getElementType(), encoding)); 203 204 return success(); 205 } 206 }; 207 208 } // namespace OpTrait 209 } // namespace mhlo 210 } // namespace mlir 211 212 #endif 213