1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for converting CHLO dialect to Linalg dialect.
17
18 #include <algorithm>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
27 #include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
30 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
31 #include "mlir/Dialect/Func/IR/FuncOps.h"
32 #include "mlir/Dialect/Linalg/IR/Linalg.h"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"
35 #include "mlir/IR/Attributes.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/MLIRContext.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/IR/OperationSupport.h"
40 #include "mlir/IR/PatternMatch.h"
41 #include "mlir/Pass/Pass.h"
42 #include "mlir/Support/LLVM.h"
43 #include "mlir/Support/LogicalResult.h"
44 #include "mlir/Transforms/DialectConversion.h"
45
46 namespace mlir {
47 namespace mhlo {
48 namespace {
49
50 struct ChloLegalizeToLinalgPass
51 : public mhlo::ChloLegalizeToLinalgPassBase<ChloLegalizeToLinalgPass> {
getDependentDialectsmlir::mhlo::__anon8f61c3810111::ChloLegalizeToLinalgPass52 void getDependentDialects(DialectRegistry& registry) const override {
53 registry
54 .insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
55 tensor::TensorDialect, sparse_tensor::SparseTensorDialect>();
56 }
57
runOnOperationmlir::mhlo::__anon8f61c3810111::ChloLegalizeToLinalgPass58 void runOnOperation() override {
59 MLIRContext* ctx = &getContext();
60 RewritePatternSet patterns(ctx);
61 ConversionTarget target(*ctx);
62 mhlo::RemoveSignTypeConverter typeConverter;
63 mhlo::populateLegalizeSparseChloToLinalgPatterns(ctx, typeConverter,
64 &patterns);
65 target.addLegalDialect<bufferization::BufferizationDialect,
66 linalg::LinalgDialect, tensor::TensorDialect,
67 sparse_tensor::SparseTensorDialect>();
68 target.addIllegalDialect<chlo::ChloDialect>();
69 /// The unary operation is sparse computation if either the input or the
70 /// result is a sparse tensor.
71 /// TODO(bixia): Remove the convert of such sparse CHLO ops from
72 /// chlo_legalize_to_hlo.
73 auto isNotSparseOp = [](Operation* op) {
74 auto encDst =
75 sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType());
76 auto encSrc =
77 sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType());
78 return !encDst && !encSrc;
79 };
80 target.addDynamicallyLegalOp<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp,
81 chlo::AtanhOp, chlo::BesselI1eOp, chlo::SinhOp,
82 chlo::TanOp>(isNotSparseOp);
83 if (failed(applyPartialConversion(getOperation(), target,
84 std::move(patterns)))) {
85 return signalPassFailure();
86 }
87 }
88 };
89
90 } // namespace
91
92 namespace impl {
93 /// Converts unary chlo op to a scalar op.
94 ///
95 /// Since the CHLO ops require tensor operands, we first create a single element
96 /// from the tensor, then perform the CHLO ops, and extract the scalar result
97 /// from the tensor. This may introduce memory accesses overhead.
98 /// TODO(bixia): Remove the extra memory accesses for performance.
99 #define ADD_OP(OpTy) \
100 template <> \
101 Value mapMhloOpToStdScalarOp<OpTy>(Location loc, ArrayRef<Type> resultTypes, \
102 ArrayRef<Type> /*arg_types*/, \
103 ValueRange args, OpBuilder * b) { \
104 Type innerResultTy = resultTypes[0]; \
105 RankedTensorType tensorResultTy = \
106 RankedTensorType::get({}, innerResultTy); \
107 Value tensorArg = \
108 b->create<tensor::FromElementsOp>(loc, tensorResultTy, args[0]); \
109 Value tensorResult = \
110 b->create<OpTy>(loc, tensorResultTy, ValueRange({tensorArg})); \
111 Value innerResult = \
112 b->create<tensor::ExtractOp>(loc, tensorResult, ValueRange({})); \
113 return innerResult; \
114 }
115
116 ADD_OP(chlo::AsinOp)
117 ADD_OP(chlo::AsinhOp)
118 ADD_OP(chlo::AtanOp)
119 ADD_OP(chlo::AtanhOp)
120 ADD_OP(chlo::BesselI1eOp)
121 ADD_OP(chlo::SinhOp)
122 ADD_OP(chlo::TanOp)
123
124 #undef ADD_OP
125
126 } // namespace impl
127
populateLegalizeSparseChloToLinalgPatterns(MLIRContext * context,TypeConverter & typeConverter,RewritePatternSet * patterns)128 void populateLegalizeSparseChloToLinalgPatterns(MLIRContext* context,
129 TypeConverter& typeConverter,
130 RewritePatternSet* patterns) {
131 patterns->add<PointwiseToLinalgConverter<chlo::AsinOp>,
132 PointwiseToLinalgConverter<chlo::AsinhOp>,
133 PointwiseToLinalgConverter<chlo::AtanOp>,
134 PointwiseToLinalgConverter<chlo::AtanhOp>,
135 PointwiseToLinalgConverter<chlo::SinhOp>,
136 PointwiseToLinalgConverter<chlo::TanOp>,
137 PointwiseToLinalgConverter<chlo::BesselI1eOp>>(typeConverter,
138 context);
139 }
140
141 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeSparseChloToLinalgPass()142 createLegalizeSparseChloToLinalgPass() {
143 return std::make_unique<ChloLegalizeToLinalgPass>();
144 }
145
146 } // namespace mhlo
147
148 } // namespace mlir
149