1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <iterator>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Casting.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
34 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
35 #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
36 #include "mlir/IR/Attributes.h" // from @llvm-project
37 #include "mlir/IR/Builders.h" // from @llvm-project
38 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
41 #include "mlir/IR/Diagnostics.h" // from @llvm-project
42 #include "mlir/IR/Location.h" // from @llvm-project
43 #include "mlir/IR/MLIRContext.h" // from @llvm-project
44 #include "mlir/IR/Matchers.h" // from @llvm-project
45 #include "mlir/IR/OperationSupport.h" // from @llvm-project
46 #include "mlir/IR/SymbolTable.h" // from @llvm-project
47 #include "mlir/IR/Types.h" // from @llvm-project
48 #include "mlir/IR/Value.h" // from @llvm-project
49 #include "mlir/IR/Visitors.h" // from @llvm-project
50 #include "mlir/Pass/Pass.h" // from @llvm-project
51 #include "mlir/Support/LLVM.h" // from @llvm-project
52 #include "mlir/Support/LogicalResult.h" // from @llvm-project
53 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
54 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
55 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
56 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
57 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
58 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
59 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
60 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
61 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
62
63 //===----------------------------------------------------------------------===//
64 // The pass to rewrite the TFR function call ops by TF ops. The callee of the
65 // TFR function call defines the signatures of the TF ops.
66 //
67 namespace mlir {
68 namespace TFR {
69
70 namespace {
71
72 // This pattern is to rewrite the "tfr.call" op and the "tfr.cast" ops on the
73 // operands by a TF op with "tfr.cast" ops on the results. The result type of
74 // the new TF op is an unranked tensor with element type derived.
75 class RewriteTFRCallOp : public OpRewritePattern<CallOp> {
76 using OpRewritePattern<CallOp>::OpRewritePattern;
77
78 public:
RewriteTFRCallOp(MLIRContext * context,const SymbolTable & table,bool materialize_derived_attrs)79 explicit RewriteTFRCallOp(MLIRContext* context, const SymbolTable& table,
80 bool materialize_derived_attrs)
81 : OpRewritePattern<CallOp>(context),
82 symbol_table_(table),
83 materialize_derived_attrs_(materialize_derived_attrs) {}
84
85 LogicalResult matchAndRewrite(CallOp call_op,
86 PatternRewriter& rewriter) const override;
87
88 private:
89 // Derives the attribute values for the attributes attached to the
90 // `input_tfr_type`. These attributes are only for the element type of the
91 // inputs, and these type information has been collected in the `input_types`.
92 // The result is stored in `derived_attrs` as the named attributes. Returns
93 // failure if the attributes stored in the `input_tfr_type` violates the
94 // assumptions.
95 LogicalResult AddDerivedAttrs(
96 PatternRewriter& rewriter, Type input_tfr_type,
97 ArrayRef<Attribute> input_types,
98 llvm::StringMap<Attribute>* derived_attrs) const;
99
100 // Collects the operands and attributes for the TF op. At the same time, it
101 // collects all the derived attribute values to derive the output types of the
102 // TF op.
103 LogicalResult CollectInputsAndAttributes(
104 PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
105 SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
106 llvm::StringMap<Attribute>* derived_attrs) const;
107
108 // Uses the collected attribute values to derive all the output types.
109 LogicalResult DeriveOutputTypes(Location loc, FunctionType signature,
110 const llvm::StringMap<Attribute>& attrs,
111 SmallVectorImpl<Type>* output_types) const;
112
113 // Creates the TF op and also the necessary tfr.cast ops to replace the
114 // original TFR call op.
115 LogicalResult CreateAndReplaceOp(
116 PatternRewriter& rewriter, CallOp call_op,
117 const SmallVectorImpl<Type>& output_types,
118 const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
119 const llvm::StringMap<Attribute>& derived_attrs) const;
120
121 // Converts the attribute to the specific type.
122 Attribute ProcessAttributeValue(Attribute attr, StringAttr attr_type) const;
123
GetFixedElementType(StringRef element_type,Builder & builder) const124 Type GetFixedElementType(StringRef element_type, Builder& builder) const {
125 if (element_type == "i32_") return builder.getI32Type();
126 if (element_type == "i64_") return builder.getI64Type();
127 if (element_type == "f32_") return builder.getF32Type();
128 if (element_type == "i1_") return builder.getI1Type();
129 return {};
130 }
131
132 // Adds a tf.Cast op if the tfr.tensor attribute indicated a fixed element
133 // type.
134 // TODO(fengliuai): This method is required when the operand types are not set
135 // by the frontend correctly.
CastToNonDerivedType(PatternRewriter & rewriter,Location loc,CastOp cast_op,Type input_tfr_type) const136 Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc,
137 CastOp cast_op, Type input_tfr_type) const {
138 auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>();
139 if (!tensor_type) return cast_op.arg();
140
141 auto attr_names = tensor_type.getAttrKeys();
142 if (attr_names.empty() || attr_names.size() > 1) return cast_op.arg();
143 StringRef tfr_type_attr = attr_names[0].getValue();
144 if (!fixed_elt_type_attrs_.contains(tfr_type_attr)) return cast_op.arg();
145
146 Type result_elt_type = GetFixedElementType(tfr_type_attr, rewriter);
147 if (!result_elt_type) {
148 return cast_op.arg();
149 }
150
151 Type original_input_type =
152 cast_op.getInputElementType().cast<TypeAttr>().getValue();
153 if (result_elt_type != original_input_type) {
154 UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type);
155 return rewriter.create<TF::CastOp>(loc, result_type, cast_op.arg());
156 }
157 return cast_op.arg();
158 }
159
160 // For variadic operands, we have to enforce them to use the same types.
161 // TODO(fengliuai): This method is required when the operand types are not set
162 // by the frontend correctly.
CastValuesToSameType(PatternRewriter & rewriter,Location loc,const llvm::SmallVectorImpl<Attribute> & input_types,llvm::SmallVectorImpl<Value> & input_values) const163 void CastValuesToSameType(PatternRewriter& rewriter, Location loc,
164 const llvm::SmallVectorImpl<Attribute>& input_types,
165 llvm::SmallVectorImpl<Value>& input_values) const {
166 if (input_types.size() <= 1) return;
167
168 Type target_input_type = input_types[0].cast<TypeAttr>().getValue();
169 auto result_type = UnrankedTensorType::get(target_input_type);
170 for (auto i = 1; i < input_types.size(); ++i) {
171 Type current_input_type = input_types[i].cast<TypeAttr>().getValue();
172 if (current_input_type != target_input_type) {
173 input_values[i] =
174 rewriter.create<TF::CastOp>(loc, result_type, input_values[i]);
175 }
176 }
177 }
178
179 const SymbolTable& symbol_table_;
180 const bool materialize_derived_attrs_;
181 const llvm::SmallDenseSet<StringRef, 4> fixed_elt_type_attrs_{"i32_", "i64_",
182 "f32_", "i1_"};
183 };
184
AddDerivedAttrs(PatternRewriter & rewriter,Type input_tfr_type,ArrayRef<Attribute> input_types,llvm::StringMap<Attribute> * derived_attrs) const185 LogicalResult RewriteTFRCallOp::AddDerivedAttrs(
186 PatternRewriter& rewriter, Type input_tfr_type,
187 ArrayRef<Attribute> input_types,
188 llvm::StringMap<Attribute>* derived_attrs) const {
189 // If there is an attribute associated to the input in the signature, we
190 // store it as an derived attribute.
191 if (auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>()) {
192 auto attr_names = tensor_type.getAttrKeys();
193 if (attr_names.empty()) return success();
194
195 if (attr_names.size() == 1) {
196 derived_attrs->insert({attr_names[0].getValue(), input_types[0]});
197 return success();
198 }
199 }
200
201 // If there is an attribute associated to the input in the signature,
202 // we store it as an derived attribute.
203 if (auto list_type = input_tfr_type.dyn_cast<TFRTensorListType>()) {
204 auto attr_names = list_type.getAttrKeys();
205 if (attr_names.empty()) return success();
206
207 // N*T case
208 if (attr_names.size() == 2) {
209 derived_attrs->insert({attr_names[0].getValue(),
210 rewriter.getI32IntegerAttr(input_types.size())});
211 // Note that this uses the first element of the list to infer the T value.
212 // A tf.Cast is required to cast the other inputs to the same type.
213 derived_attrs->insert({attr_names[1].getValue(), input_types[0]});
214 return success();
215 }
216
217 // list(dtype) case
218 if (attr_names.size() == 1) {
219 derived_attrs->insert(
220 {attr_names[0].getValue(), rewriter.getArrayAttr(input_types)});
221 return success();
222 }
223 }
224
225 return failure();
226 }
227
CollectInputsAndAttributes(PatternRewriter & rewriter,TFRFuncOp signature,CallOp call_op,SmallVectorImpl<Value> * inputs,NamedAttrList * arg_attrs,llvm::StringMap<Attribute> * derived_attrs) const228 LogicalResult RewriteTFRCallOp::CollectInputsAndAttributes(
229 PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
230 SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
231 llvm::StringMap<Attribute>* derived_attrs) const {
232 for (const auto& operand :
233 llvm::enumerate(signature.getFunctionType().getInputs())) {
234 // If the index is larger than the operand number of the call_op, the
235 // default value of the operand needs to be used.
236 if (operand.index() >= call_op.getNumOperands()) {
237 auto attr_name = signature.getArgAttrOfType<StringAttr>(
238 operand.index(), kAttrArgumentNameAttr);
239 auto attr_value =
240 signature.getArgAttr(operand.index(), kAttrArgumentDefaultAttr);
241 arg_attrs->push_back(
242 rewriter.getNamedAttr(attr_name.getValue(), attr_value));
243 continue;
244 }
245
246 // The index is valid for the call_op.
247 Value input = call_op.getOperand(operand.index());
248 Operation* input_op = input.getDefiningOp();
249 auto input_tfr_type =
250 signature.getFunctionType().getInputs()[operand.index()];
251
252 // There are three cases for the preceding input_op:
253
254 // 1. The preceding op can be a tfr.cast op, which will be fused to the
255 // current op, so the result op has input with tensor type.
256 if (auto cast_op = dyn_cast_or_null<CastOp>(input_op)) {
257 Value input_to_cast = CastToNonDerivedType(rewriter, call_op.getLoc(),
258 cast_op, input_tfr_type);
259 inputs->push_back(input_to_cast);
260 if (failed(AddDerivedAttrs(rewriter, input_tfr_type,
261 {cast_op.getInputElementType()},
262 derived_attrs))) {
263 return failure();
264 }
265 continue;
266 }
267
268 // 2. The preceding op is a tfr.build_list op, which collects multiple
269 // values with tensor types via the tfr.cast ops. These ops will be fused
270 // to the current op as well, so all the tfr.cast op inputs will be inputs
271 // to the result op.
272 if (auto list_op = dyn_cast_or_null<BuildListOp>(input_op)) {
273 // Find out all the inputs to the build list op
274 // TODO(fengliuai): make build_list op only take tensor argument
275 llvm::SmallVector<Attribute, 4> list_input_types;
276 llvm::SmallVector<Value, 4> list_inputs;
277 for (auto list_input : list_op.getOperands()) {
278 auto cast_op = dyn_cast_or_null<CastOp>(list_input.getDefiningOp());
279 if (!cast_op) return failure();
280 list_inputs.push_back(cast_op.arg());
281 list_input_types.push_back(cast_op.getInputElementType());
282 }
283 CastValuesToSameType(rewriter, call_op.getLoc(), list_input_types,
284 list_inputs);
285 inputs->append(list_inputs.begin(), list_inputs.end());
286 if (failed(AddDerivedAttrs(rewriter, input_tfr_type, list_input_types,
287 derived_attrs))) {
288 return failure();
289 }
290 continue;
291 }
292
293 // 3. The preceding op is a constant, thus the value of this constant is
294 // used to create an attribute of the result op, according to the signature.
295 Attribute arg_value;
296 // A failure indicates the argument isn't a constant value, so we should
297 // not use it as an attribute.
298 if (!matchPattern(input, m_Constant(&arg_value))) {
299 return failure();
300 }
301 auto attr_name = signature.getArgAttrOfType<StringAttr>(
302 operand.index(), kAttrArgumentNameAttr);
303 auto attr_type = signature.getArgAttrOfType<StringAttr>(
304 operand.index(), kAttrArgumentTypeAttr);
305 auto value = ProcessAttributeValue(arg_value, attr_type);
306 arg_attrs->push_back(rewriter.getNamedAttr(attr_name.getValue(), value));
307 }
308 return success();
309 }
310
ProcessAttributeValue(Attribute attr,StringAttr attr_type) const311 Attribute RewriteTFRCallOp::ProcessAttributeValue(Attribute attr,
312 StringAttr attr_type) const {
313 if (!attr_type) return attr;
314
315 if (attr_type.getValue() == "tensor") {
316 if (auto f = attr.dyn_cast<FloatAttr>()) {
317 RankedTensorType type = RankedTensorType::get({}, f.getType());
318 return DenseFPElementsAttr::get(type, attr);
319 }
320 // TODO(fengliuai): handles ArrayAttr. Note that it can be nested ArrayAttr.
321 }
322
323 return attr;
324 }
325
326 // For each output, uses the attribute name associated to the tfr types to find
327 // out the attribute value from the collected `attrs` and create the output type
328 // of the result op by using the attribute value as the element type.
DeriveOutputTypes(Location loc,FunctionType signature,const llvm::StringMap<Attribute> & attrs,SmallVectorImpl<Type> * output_types) const329 LogicalResult RewriteTFRCallOp::DeriveOutputTypes(
330 Location loc, FunctionType signature,
331 const llvm::StringMap<Attribute>& attrs,
332 SmallVectorImpl<Type>* output_types) const {
333 for (auto res : llvm::enumerate(signature.getResults())) {
334 if (auto tensor_type = res.value().dyn_cast<TFRTensorType>()) {
335 // tfr.tensor should only have one attribute attached.
336 auto attr_key = tensor_type.getAttrKeys().front();
337 Builder builder(signature.getContext());
338 if (auto attr = attrs.lookup(attr_key.getValue())) {
339 output_types->push_back(
340 UnrankedTensorType::get(attr.cast<TypeAttr>().getValue()));
341 } else if (Type element_type =
342 GetFixedElementType(attr_key.getValue(), builder)) {
343 output_types->push_back(UnrankedTensorType::get(element_type));
344 } else {
345 emitError(loc) << "type " << attr_key.getValue()
346 << " can't be resolved for the signature of the op";
347 return failure();
348 }
349 continue;
350 }
351
352 if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
353 // There are two cases: N*T or list(dtype)
354 auto attr_keys = list_type.getAttrKeys();
355 // N*T case
356 if (attr_keys.size() == 2) {
357 // The first one is N, and the second one is T
358 int list_size =
359 attrs.lookup(attr_keys[0].getValue()).cast<IntegerAttr>().getInt();
360 Type list_type =
361 attrs.lookup(attr_keys[1].getValue()).cast<TypeAttr>().getValue();
362 for (int i = 0; i < list_size; ++i) {
363 output_types->push_back(UnrankedTensorType::get(list_type));
364 }
365 continue;
366 }
367 // TODO(fengliuai): list(dtype) case
368 }
369 return failure();
370 }
371 return success();
372 }
373
CreateAndReplaceOp(PatternRewriter & rewriter,CallOp call_op,const SmallVectorImpl<Type> & output_types,const SmallVectorImpl<Value> & inputs,const NamedAttrList & attr_list,const llvm::StringMap<Attribute> & derived_attrs) const374 LogicalResult RewriteTFRCallOp::CreateAndReplaceOp(
375 PatternRewriter& rewriter, CallOp call_op,
376 const SmallVectorImpl<Type>& output_types,
377 const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
378 const llvm::StringMap<Attribute>& derived_attrs) const {
379 // Create the new op
380 Location loc = call_op.getLoc();
381 rewriter.setInsertionPointAfter(call_op);
382 std::string tf_op_name = GetTFOpName(call_op.callee());
383 OperationState new_state(loc, tf_op_name, inputs, output_types, attr_list);
384 Operation* new_op = rewriter.create(new_state);
385 if (materialize_derived_attrs_) {
386 for (const auto& attr : derived_attrs) {
387 // Add or update the derived attribute with the value. Skip the fixed
388 // element type attributes, in case they are present in the NodeDef.
389 if (!fixed_elt_type_attrs_.contains(attr.first())) {
390 new_op->setAttr(attr.first(), attr.second);
391 }
392 }
393 }
394 // Create the tfr.cast ops on the results and replace the uses of the
395 // original call op.
396 TFRTensorType unconstrainted_type = rewriter.getType<TFRTensorType>();
397 SmallVector<Value, 4> new_results;
398 for (auto res : llvm::enumerate(call_op.getResultTypes())) {
399 Type res_type = res.value();
400 if (res_type.dyn_cast<TFRTensorType>()) {
401 Value new_res = new_op->getResult(res.index());
402 auto casted = rewriter.create<CastOp>(loc, res_type, new_res);
403 new_results.push_back(casted.out());
404 } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
405 SmallVector<Value, 4> tensor_list;
406 for (int i = res.index(); i < new_op->getNumResults(); i++) {
407 Value new_res = new_op->getResult(i);
408 auto casted =
409 rewriter.create<CastOp>(loc, unconstrainted_type, new_res);
410 tensor_list.push_back(casted.out());
411 }
412 auto list_op = rewriter.create<BuildListOp>(loc, res_type, tensor_list);
413 new_results.push_back(list_op.out());
414 }
415 }
416
417 // Copy all the allowed attributes to the new op.
418 if (failed(CopyNonSymbolRefAttrs(call_op, new_op))) return failure();
419
420 rewriter.replaceOp(call_op, new_results);
421 return success();
422 }
423
matchAndRewrite(CallOp call_op,PatternRewriter & rewriter) const424 LogicalResult RewriteTFRCallOp::matchAndRewrite(
425 CallOp call_op, PatternRewriter& rewriter) const {
426 // Get the func op and verify that it is external. The type of this external
427 // func op is used as the signature of the corresponding TF ops. All the
428 // external func ops have the trailing underscore.
429 std::string external_callee_name = call_op.callee().str().append("_");
430 TFRFuncOp func = symbol_table_.lookup<TFRFuncOp>(external_callee_name);
431 if (!func || !func.isExternal()) return failure();
432 // Get the inputs and attributes. The attributes include these from the
433 // argument list and also these derived from the inputs.
434 SmallVector<Value, 4> inputs;
435 NamedAttrList argument_attrs;
436 llvm::StringMap<Attribute> derived_attrs;
437 if (failed(CollectInputsAndAttributes(rewriter, func, call_op, &inputs,
438 &argument_attrs, &derived_attrs))) {
439 return failure();
440 }
441
442 // Derive the output types. The result type is derived by using the
443 // attributes attched to the result type of the signature. The attribute
444 // value should be either in the attribute argument list or the derived
445 // attribute from the input tensors. All the result type
446 // are unranked, and shape inference should be applied afterwards.
447 SmallVector<Type, 4> output_types;
448
449 // Merge the attributes from the argument list to the derived ones.
450 for (auto& attr : argument_attrs) {
451 derived_attrs.insert({attr.getName(), attr.getValue()});
452 }
453
454 // Derive the output types by using the attributes attached to the tfr
455 // types.
456 if (failed(DeriveOutputTypes(call_op->getLoc(), func.getFunctionType(),
457 derived_attrs, &output_types))) {
458 return failure();
459 }
460
461 // Create the new op and replace the old TFR call op.
462 return CreateAndReplaceOp(rewriter, call_op, output_types, inputs,
463 argument_attrs, derived_attrs);
464 }
465
466 // Raise TFR call ops to the TF ops.
467 class RaiseToTFOpsPass
468 : public PassWrapper<RaiseToTFOpsPass, OperationPass<func::FuncOp>> {
469 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseToTFOpsPass)470 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseToTFOpsPass)
471
472 void getDependentDialects(DialectRegistry& registry) const override {
473 registry.insert<TFRDialect, TF::TensorFlowDialect, scf::SCFDialect,
474 arith::ArithmeticDialect, func::FuncDialect>();
475 }
476
RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,bool materialize_derived_attrs)477 explicit RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,
478 bool materialize_derived_attrs)
479 : external_tfr_module_(tfr_module),
480 materialize_derived_attrs_(materialize_derived_attrs) {}
481
getArgument() const482 StringRef getArgument() const final { return "tfr-raise-to-tf"; }
483
getDescription() const484 StringRef getDescription() const final {
485 return "Raise all the TFR call ops to TF ops.";
486 }
487
488 void runOnOperation() override;
489
490 private:
491 llvm::Optional<ModuleOp> external_tfr_module_;
492 const bool materialize_derived_attrs_;
493 };
494
runOnOperation()495 void RaiseToTFOpsPass::runOnOperation() {
496 func::FuncOp func = getOperation();
497 MLIRContext* ctx = &getContext();
498 SymbolTable table(external_tfr_module_.has_value()
499 ? *external_tfr_module_
500 : func->getParentOfType<ModuleOp>());
501
502 RewritePatternSet patterns(&getContext());
503 patterns.add<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs_);
504
505 populateCanonicalizationPatterns(func, patterns);
506
507 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
508 }
509 } // namespace
510
511 // Creates an instance of the pass to raise TFR call ops to the TF ops.
CreateRaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,bool materialize_derived_attrs)512 std::unique_ptr<OperationPass<func::FuncOp>> CreateRaiseToTFOpsPass(
513 llvm::Optional<ModuleOp> tfr_module, bool materialize_derived_attrs) {
514 return std::make_unique<RaiseToTFOpsPass>(tfr_module,
515 materialize_derived_attrs);
516 }
517
__anonf1e22a540202null518 static PassRegistration<RaiseToTFOpsPass> pass([] {
519 return CreateRaiseToTFOpsPass();
520 });
521
522 } // namespace TFR
523 } // namespace mlir
524