1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements a pass that converts pointwise unary CHLO ops operating
// on sparse tensors to the Linalg dialect.
17 
18 #include <algorithm>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
27 #include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
30 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
31 #include "mlir/Dialect/Func/IR/FuncOps.h"
32 #include "mlir/Dialect/Linalg/IR/Linalg.h"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"
35 #include "mlir/IR/Attributes.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/MLIRContext.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/IR/OperationSupport.h"
40 #include "mlir/IR/PatternMatch.h"
41 #include "mlir/Pass/Pass.h"
42 #include "mlir/Support/LLVM.h"
43 #include "mlir/Support/LogicalResult.h"
44 #include "mlir/Transforms/DialectConversion.h"
45 
46 namespace mlir {
47 namespace mhlo {
48 namespace {
49 
50 struct ChloLegalizeToLinalgPass
51     : public mhlo::ChloLegalizeToLinalgPassBase<ChloLegalizeToLinalgPass> {
getDependentDialectsmlir::mhlo::__anon8f61c3810111::ChloLegalizeToLinalgPass52   void getDependentDialects(DialectRegistry& registry) const override {
53     registry
54         .insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
55                 tensor::TensorDialect, sparse_tensor::SparseTensorDialect>();
56   }
57 
runOnOperationmlir::mhlo::__anon8f61c3810111::ChloLegalizeToLinalgPass58   void runOnOperation() override {
59     MLIRContext* ctx = &getContext();
60     RewritePatternSet patterns(ctx);
61     ConversionTarget target(*ctx);
62     mhlo::RemoveSignTypeConverter typeConverter;
63     mhlo::populateLegalizeSparseChloToLinalgPatterns(ctx, typeConverter,
64                                                      &patterns);
65     target.addLegalDialect<bufferization::BufferizationDialect,
66                            linalg::LinalgDialect, tensor::TensorDialect,
67                            sparse_tensor::SparseTensorDialect>();
68     target.addIllegalDialect<chlo::ChloDialect>();
69     /// The unary operation is sparse computation if either the input or the
70     /// result is a sparse tensor.
71     /// TODO(bixia): Remove the convert of such sparse CHLO ops from
72     /// chlo_legalize_to_hlo.
73     auto isNotSparseOp = [](Operation* op) {
74       auto encDst =
75           sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType());
76       auto encSrc =
77           sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType());
78       return !encDst && !encSrc;
79     };
80     target.addDynamicallyLegalOp<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp,
81                                  chlo::AtanhOp, chlo::BesselI1eOp, chlo::SinhOp,
82                                  chlo::TanOp>(isNotSparseOp);
83     if (failed(applyPartialConversion(getOperation(), target,
84                                       std::move(patterns)))) {
85       return signalPassFailure();
86     }
87   }
88 };
89 
90 }  // namespace
91 
92 namespace impl {
93 /// Converts unary chlo op to a scalar op.
94 ///
95 /// Since the CHLO ops require tensor operands, we first create a single element
96 /// from the tensor, then perform the CHLO ops, and extract the scalar result
97 /// from the tensor. This may introduce memory accesses overhead.
98 /// TODO(bixia): Remove the extra memory accesses for performance.
99 #define ADD_OP(OpTy)                                                           \
100   template <>                                                                  \
101   Value mapMhloOpToStdScalarOp<OpTy>(Location loc, ArrayRef<Type> resultTypes, \
102                                      ArrayRef<Type> /*arg_types*/,             \
103                                      ValueRange args, OpBuilder * b) {         \
104     Type innerResultTy = resultTypes[0];                                       \
105     RankedTensorType tensorResultTy =                                          \
106         RankedTensorType::get({}, innerResultTy);                              \
107     Value tensorArg =                                                          \
108         b->create<tensor::FromElementsOp>(loc, tensorResultTy, args[0]);       \
109     Value tensorResult =                                                       \
110         b->create<OpTy>(loc, tensorResultTy, ValueRange({tensorArg}));         \
111     Value innerResult =                                                        \
112         b->create<tensor::ExtractOp>(loc, tensorResult, ValueRange({}));       \
113     return innerResult;                                                        \
114   }
115 
116 ADD_OP(chlo::AsinOp)
117 ADD_OP(chlo::AsinhOp)
118 ADD_OP(chlo::AtanOp)
119 ADD_OP(chlo::AtanhOp)
120 ADD_OP(chlo::BesselI1eOp)
121 ADD_OP(chlo::SinhOp)
122 ADD_OP(chlo::TanOp)
123 
124 #undef ADD_OP
125 
126 }  // namespace impl
127 
populateLegalizeSparseChloToLinalgPatterns(MLIRContext * context,TypeConverter & typeConverter,RewritePatternSet * patterns)128 void populateLegalizeSparseChloToLinalgPatterns(MLIRContext* context,
129                                                 TypeConverter& typeConverter,
130                                                 RewritePatternSet* patterns) {
131   patterns->add<PointwiseToLinalgConverter<chlo::AsinOp>,
132                 PointwiseToLinalgConverter<chlo::AsinhOp>,
133                 PointwiseToLinalgConverter<chlo::AtanOp>,
134                 PointwiseToLinalgConverter<chlo::AtanhOp>,
135                 PointwiseToLinalgConverter<chlo::SinhOp>,
136                 PointwiseToLinalgConverter<chlo::TanOp>,
137                 PointwiseToLinalgConverter<chlo::BesselI1eOp>>(typeConverter,
138                                                                context);
139 }
140 
141 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeSparseChloToLinalgPass()142 createLegalizeSparseChloToLinalgPass() {
143   return std::make_unique<ChloLegalizeToLinalgPass>();
144 }
145 
146 }  // namespace mhlo
147 
148 }  // namespace mlir
149