xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <iterator>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Casting.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
34 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
35 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
36 #include "mlir/IR/Attributes.h"  // from @llvm-project
37 #include "mlir/IR/Builders.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
41 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
42 #include "mlir/IR/Location.h"  // from @llvm-project
43 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
44 #include "mlir/IR/Matchers.h"  // from @llvm-project
45 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
46 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
47 #include "mlir/IR/Types.h"  // from @llvm-project
48 #include "mlir/IR/Value.h"  // from @llvm-project
49 #include "mlir/IR/Visitors.h"  // from @llvm-project
50 #include "mlir/Pass/Pass.h"  // from @llvm-project
51 #include "mlir/Support/LLVM.h"  // from @llvm-project
52 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
53 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
54 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
55 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
56 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
57 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
58 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
59 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
60 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
61 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
62 
63 //===----------------------------------------------------------------------===//
64 // The pass to rewrite the TFR function call ops by TF ops. The callee of the
65 // TFR function call defines the signatures of the TF ops.
66 //
67 namespace mlir {
68 namespace TFR {
69 
70 namespace {
71 
72 // This pattern is to rewrite the "tfr.call" op and the "tfr.cast" ops on the
73 // operands by a TF op with "tfr.cast" ops on the results. The result type of
74 // the new TF op is an unranked tensor with element type derived.
75 class RewriteTFRCallOp : public OpRewritePattern<CallOp> {
76   using OpRewritePattern<CallOp>::OpRewritePattern;
77 
78  public:
RewriteTFRCallOp(MLIRContext * context,const SymbolTable & table,bool materialize_derived_attrs)79   explicit RewriteTFRCallOp(MLIRContext* context, const SymbolTable& table,
80                             bool materialize_derived_attrs)
81       : OpRewritePattern<CallOp>(context),
82         symbol_table_(table),
83         materialize_derived_attrs_(materialize_derived_attrs) {}
84 
85   LogicalResult matchAndRewrite(CallOp call_op,
86                                 PatternRewriter& rewriter) const override;
87 
88  private:
89   // Derives the attribute values for the attributes attached to the
90   // `input_tfr_type`. These attributes are only for the element type of the
91   // inputs, and these type information has been collected in the `input_types`.
92   // The result is stored in `derived_attrs` as the named attributes. Returns
93   // failure if the attributes stored in the `input_tfr_type` violates the
94   // assumptions.
95   LogicalResult AddDerivedAttrs(
96       PatternRewriter& rewriter, Type input_tfr_type,
97       ArrayRef<Attribute> input_types,
98       llvm::StringMap<Attribute>* derived_attrs) const;
99 
100   // Collects the operands and attributes for the TF op. At the same time, it
101   // collects all the derived attribute values to derive the output types of the
102   // TF op.
103   LogicalResult CollectInputsAndAttributes(
104       PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
105       SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
106       llvm::StringMap<Attribute>* derived_attrs) const;
107 
108   // Uses the collected attribute values to derive all the output types.
109   LogicalResult DeriveOutputTypes(Location loc, FunctionType signature,
110                                   const llvm::StringMap<Attribute>& attrs,
111                                   SmallVectorImpl<Type>* output_types) const;
112 
113   // Creates the TF op and also the necessary tfr.cast ops to replace the
114   // original TFR call op.
115   LogicalResult CreateAndReplaceOp(
116       PatternRewriter& rewriter, CallOp call_op,
117       const SmallVectorImpl<Type>& output_types,
118       const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
119       const llvm::StringMap<Attribute>& derived_attrs) const;
120 
121   // Converts the attribute to the specific type.
122   Attribute ProcessAttributeValue(Attribute attr, StringAttr attr_type) const;
123 
GetFixedElementType(StringRef element_type,Builder & builder) const124   Type GetFixedElementType(StringRef element_type, Builder& builder) const {
125     if (element_type == "i32_") return builder.getI32Type();
126     if (element_type == "i64_") return builder.getI64Type();
127     if (element_type == "f32_") return builder.getF32Type();
128     if (element_type == "i1_") return builder.getI1Type();
129     return {};
130   }
131 
132   // Adds a tf.Cast op if the tfr.tensor attribute indicated a fixed element
133   // type.
134   // TODO(fengliuai): This method is required when the operand types are not set
135   // by the frontend correctly.
CastToNonDerivedType(PatternRewriter & rewriter,Location loc,CastOp cast_op,Type input_tfr_type) const136   Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc,
137                              CastOp cast_op, Type input_tfr_type) const {
138     auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>();
139     if (!tensor_type) return cast_op.arg();
140 
141     auto attr_names = tensor_type.getAttrKeys();
142     if (attr_names.empty() || attr_names.size() > 1) return cast_op.arg();
143     StringRef tfr_type_attr = attr_names[0].getValue();
144     if (!fixed_elt_type_attrs_.contains(tfr_type_attr)) return cast_op.arg();
145 
146     Type result_elt_type = GetFixedElementType(tfr_type_attr, rewriter);
147     if (!result_elt_type) {
148       return cast_op.arg();
149     }
150 
151     Type original_input_type =
152         cast_op.getInputElementType().cast<TypeAttr>().getValue();
153     if (result_elt_type != original_input_type) {
154       UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type);
155       return rewriter.create<TF::CastOp>(loc, result_type, cast_op.arg());
156     }
157     return cast_op.arg();
158   }
159 
160   // For variadic operands, we have to enforce them to use the same types.
161   // TODO(fengliuai): This method is required when the operand types are not set
162   // by the frontend correctly.
CastValuesToSameType(PatternRewriter & rewriter,Location loc,const llvm::SmallVectorImpl<Attribute> & input_types,llvm::SmallVectorImpl<Value> & input_values) const163   void CastValuesToSameType(PatternRewriter& rewriter, Location loc,
164                             const llvm::SmallVectorImpl<Attribute>& input_types,
165                             llvm::SmallVectorImpl<Value>& input_values) const {
166     if (input_types.size() <= 1) return;
167 
168     Type target_input_type = input_types[0].cast<TypeAttr>().getValue();
169     auto result_type = UnrankedTensorType::get(target_input_type);
170     for (auto i = 1; i < input_types.size(); ++i) {
171       Type current_input_type = input_types[i].cast<TypeAttr>().getValue();
172       if (current_input_type != target_input_type) {
173         input_values[i] =
174             rewriter.create<TF::CastOp>(loc, result_type, input_values[i]);
175       }
176     }
177   }
178 
179   const SymbolTable& symbol_table_;
180   const bool materialize_derived_attrs_;
181   const llvm::SmallDenseSet<StringRef, 4> fixed_elt_type_attrs_{"i32_", "i64_",
182                                                                 "f32_", "i1_"};
183 };
184 
AddDerivedAttrs(PatternRewriter & rewriter,Type input_tfr_type,ArrayRef<Attribute> input_types,llvm::StringMap<Attribute> * derived_attrs) const185 LogicalResult RewriteTFRCallOp::AddDerivedAttrs(
186     PatternRewriter& rewriter, Type input_tfr_type,
187     ArrayRef<Attribute> input_types,
188     llvm::StringMap<Attribute>* derived_attrs) const {
189   // If there is an attribute associated to the input in the signature, we
190   // store it as an derived attribute.
191   if (auto tensor_type = input_tfr_type.dyn_cast<TFRTensorType>()) {
192     auto attr_names = tensor_type.getAttrKeys();
193     if (attr_names.empty()) return success();
194 
195     if (attr_names.size() == 1) {
196       derived_attrs->insert({attr_names[0].getValue(), input_types[0]});
197       return success();
198     }
199   }
200 
201   // If there is an attribute associated to the input in the signature,
202   // we store it as an derived attribute.
203   if (auto list_type = input_tfr_type.dyn_cast<TFRTensorListType>()) {
204     auto attr_names = list_type.getAttrKeys();
205     if (attr_names.empty()) return success();
206 
207     // N*T case
208     if (attr_names.size() == 2) {
209       derived_attrs->insert({attr_names[0].getValue(),
210                              rewriter.getI32IntegerAttr(input_types.size())});
211       // Note that this uses the first element of the list to infer the T value.
212       // A tf.Cast is required to cast the other inputs to the same type.
213       derived_attrs->insert({attr_names[1].getValue(), input_types[0]});
214       return success();
215     }
216 
217     // list(dtype) case
218     if (attr_names.size() == 1) {
219       derived_attrs->insert(
220           {attr_names[0].getValue(), rewriter.getArrayAttr(input_types)});
221       return success();
222     }
223   }
224 
225   return failure();
226 }
227 
CollectInputsAndAttributes(PatternRewriter & rewriter,TFRFuncOp signature,CallOp call_op,SmallVectorImpl<Value> * inputs,NamedAttrList * arg_attrs,llvm::StringMap<Attribute> * derived_attrs) const228 LogicalResult RewriteTFRCallOp::CollectInputsAndAttributes(
229     PatternRewriter& rewriter, TFRFuncOp signature, CallOp call_op,
230     SmallVectorImpl<Value>* inputs, NamedAttrList* arg_attrs,
231     llvm::StringMap<Attribute>* derived_attrs) const {
232   for (const auto& operand :
233        llvm::enumerate(signature.getFunctionType().getInputs())) {
234     // If the index is larger than the operand number of the call_op, the
235     // default value of the operand needs to be used.
236     if (operand.index() >= call_op.getNumOperands()) {
237       auto attr_name = signature.getArgAttrOfType<StringAttr>(
238           operand.index(), kAttrArgumentNameAttr);
239       auto attr_value =
240           signature.getArgAttr(operand.index(), kAttrArgumentDefaultAttr);
241       arg_attrs->push_back(
242           rewriter.getNamedAttr(attr_name.getValue(), attr_value));
243       continue;
244     }
245 
246     // The index is valid for the call_op.
247     Value input = call_op.getOperand(operand.index());
248     Operation* input_op = input.getDefiningOp();
249     auto input_tfr_type =
250         signature.getFunctionType().getInputs()[operand.index()];
251 
252     // There are three cases for the preceding input_op:
253 
254     // 1. The preceding op can be a tfr.cast op, which will be fused to the
255     // current op, so the result op has input with tensor type.
256     if (auto cast_op = dyn_cast_or_null<CastOp>(input_op)) {
257       Value input_to_cast = CastToNonDerivedType(rewriter, call_op.getLoc(),
258                                                  cast_op, input_tfr_type);
259       inputs->push_back(input_to_cast);
260       if (failed(AddDerivedAttrs(rewriter, input_tfr_type,
261                                  {cast_op.getInputElementType()},
262                                  derived_attrs))) {
263         return failure();
264       }
265       continue;
266     }
267 
268     // 2. The preceding op is a tfr.build_list op, which collects multiple
269     // values with tensor types via the tfr.cast ops. These ops will be fused
270     // to the current op as well, so all the tfr.cast op inputs will be inputs
271     // to the result op.
272     if (auto list_op = dyn_cast_or_null<BuildListOp>(input_op)) {
273       // Find out all the inputs to the build list op
274       // TODO(fengliuai): make build_list op only take tensor argument
275       llvm::SmallVector<Attribute, 4> list_input_types;
276       llvm::SmallVector<Value, 4> list_inputs;
277       for (auto list_input : list_op.getOperands()) {
278         auto cast_op = dyn_cast_or_null<CastOp>(list_input.getDefiningOp());
279         if (!cast_op) return failure();
280         list_inputs.push_back(cast_op.arg());
281         list_input_types.push_back(cast_op.getInputElementType());
282       }
283       CastValuesToSameType(rewriter, call_op.getLoc(), list_input_types,
284                            list_inputs);
285       inputs->append(list_inputs.begin(), list_inputs.end());
286       if (failed(AddDerivedAttrs(rewriter, input_tfr_type, list_input_types,
287                                  derived_attrs))) {
288         return failure();
289       }
290       continue;
291     }
292 
293     // 3. The preceding op is a constant, thus the value of this constant is
294     // used to create an attribute of the result op, according to the signature.
295     Attribute arg_value;
296     // A failure indicates the argument isn't a constant value, so we should
297     // not use it as an attribute.
298     if (!matchPattern(input, m_Constant(&arg_value))) {
299       return failure();
300     }
301     auto attr_name = signature.getArgAttrOfType<StringAttr>(
302         operand.index(), kAttrArgumentNameAttr);
303     auto attr_type = signature.getArgAttrOfType<StringAttr>(
304         operand.index(), kAttrArgumentTypeAttr);
305     auto value = ProcessAttributeValue(arg_value, attr_type);
306     arg_attrs->push_back(rewriter.getNamedAttr(attr_name.getValue(), value));
307   }
308   return success();
309 }
310 
ProcessAttributeValue(Attribute attr,StringAttr attr_type) const311 Attribute RewriteTFRCallOp::ProcessAttributeValue(Attribute attr,
312                                                   StringAttr attr_type) const {
313   if (!attr_type) return attr;
314 
315   if (attr_type.getValue() == "tensor") {
316     if (auto f = attr.dyn_cast<FloatAttr>()) {
317       RankedTensorType type = RankedTensorType::get({}, f.getType());
318       return DenseFPElementsAttr::get(type, attr);
319     }
320     // TODO(fengliuai): handles ArrayAttr. Note that it can be nested ArrayAttr.
321   }
322 
323   return attr;
324 }
325 
326 // For each output, uses the attribute name associated to the tfr types to find
327 // out the attribute value from the collected `attrs` and create the output type
328 // of the result op by using the attribute value as the element type.
DeriveOutputTypes(Location loc,FunctionType signature,const llvm::StringMap<Attribute> & attrs,SmallVectorImpl<Type> * output_types) const329 LogicalResult RewriteTFRCallOp::DeriveOutputTypes(
330     Location loc, FunctionType signature,
331     const llvm::StringMap<Attribute>& attrs,
332     SmallVectorImpl<Type>* output_types) const {
333   for (auto res : llvm::enumerate(signature.getResults())) {
334     if (auto tensor_type = res.value().dyn_cast<TFRTensorType>()) {
335       // tfr.tensor should only have one attribute attached.
336       auto attr_key = tensor_type.getAttrKeys().front();
337       Builder builder(signature.getContext());
338       if (auto attr = attrs.lookup(attr_key.getValue())) {
339         output_types->push_back(
340             UnrankedTensorType::get(attr.cast<TypeAttr>().getValue()));
341       } else if (Type element_type =
342                      GetFixedElementType(attr_key.getValue(), builder)) {
343         output_types->push_back(UnrankedTensorType::get(element_type));
344       } else {
345         emitError(loc) << "type " << attr_key.getValue()
346                        << " can't be resolved for the signature of the op";
347         return failure();
348       }
349       continue;
350     }
351 
352     if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
353       // There are two cases: N*T or list(dtype)
354       auto attr_keys = list_type.getAttrKeys();
355       // N*T case
356       if (attr_keys.size() == 2) {
357         // The first one is N, and the second one is T
358         int list_size =
359             attrs.lookup(attr_keys[0].getValue()).cast<IntegerAttr>().getInt();
360         Type list_type =
361             attrs.lookup(attr_keys[1].getValue()).cast<TypeAttr>().getValue();
362         for (int i = 0; i < list_size; ++i) {
363           output_types->push_back(UnrankedTensorType::get(list_type));
364         }
365         continue;
366       }
367       // TODO(fengliuai): list(dtype) case
368     }
369     return failure();
370   }
371   return success();
372 }
373 
CreateAndReplaceOp(PatternRewriter & rewriter,CallOp call_op,const SmallVectorImpl<Type> & output_types,const SmallVectorImpl<Value> & inputs,const NamedAttrList & attr_list,const llvm::StringMap<Attribute> & derived_attrs) const374 LogicalResult RewriteTFRCallOp::CreateAndReplaceOp(
375     PatternRewriter& rewriter, CallOp call_op,
376     const SmallVectorImpl<Type>& output_types,
377     const SmallVectorImpl<Value>& inputs, const NamedAttrList& attr_list,
378     const llvm::StringMap<Attribute>& derived_attrs) const {
379   // Create the new op
380   Location loc = call_op.getLoc();
381   rewriter.setInsertionPointAfter(call_op);
382   std::string tf_op_name = GetTFOpName(call_op.callee());
383   OperationState new_state(loc, tf_op_name, inputs, output_types, attr_list);
384   Operation* new_op = rewriter.create(new_state);
385   if (materialize_derived_attrs_) {
386     for (const auto& attr : derived_attrs) {
387       // Add or update the derived attribute with the value. Skip the fixed
388       // element type attributes, in case they are present in the NodeDef.
389       if (!fixed_elt_type_attrs_.contains(attr.first())) {
390         new_op->setAttr(attr.first(), attr.second);
391       }
392     }
393   }
394   // Create the tfr.cast ops on the results and replace the uses of the
395   // original call op.
396   TFRTensorType unconstrainted_type = rewriter.getType<TFRTensorType>();
397   SmallVector<Value, 4> new_results;
398   for (auto res : llvm::enumerate(call_op.getResultTypes())) {
399     Type res_type = res.value();
400     if (res_type.dyn_cast<TFRTensorType>()) {
401       Value new_res = new_op->getResult(res.index());
402       auto casted = rewriter.create<CastOp>(loc, res_type, new_res);
403       new_results.push_back(casted.out());
404     } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
405       SmallVector<Value, 4> tensor_list;
406       for (int i = res.index(); i < new_op->getNumResults(); i++) {
407         Value new_res = new_op->getResult(i);
408         auto casted =
409             rewriter.create<CastOp>(loc, unconstrainted_type, new_res);
410         tensor_list.push_back(casted.out());
411       }
412       auto list_op = rewriter.create<BuildListOp>(loc, res_type, tensor_list);
413       new_results.push_back(list_op.out());
414     }
415   }
416 
417   // Copy all the allowed attributes to the new op.
418   if (failed(CopyNonSymbolRefAttrs(call_op, new_op))) return failure();
419 
420   rewriter.replaceOp(call_op, new_results);
421   return success();
422 }
423 
matchAndRewrite(CallOp call_op,PatternRewriter & rewriter) const424 LogicalResult RewriteTFRCallOp::matchAndRewrite(
425     CallOp call_op, PatternRewriter& rewriter) const {
426   // Get the func op and verify that it is external. The type of this external
427   // func op is used as the signature of the corresponding TF ops. All the
428   // external func ops have the trailing underscore.
429   std::string external_callee_name = call_op.callee().str().append("_");
430   TFRFuncOp func = symbol_table_.lookup<TFRFuncOp>(external_callee_name);
431   if (!func || !func.isExternal()) return failure();
432   // Get the inputs and attributes. The attributes include these from the
433   // argument list and also these derived from the inputs.
434   SmallVector<Value, 4> inputs;
435   NamedAttrList argument_attrs;
436   llvm::StringMap<Attribute> derived_attrs;
437   if (failed(CollectInputsAndAttributes(rewriter, func, call_op, &inputs,
438                                         &argument_attrs, &derived_attrs))) {
439     return failure();
440   }
441 
442   // Derive the output types. The result type is derived by using the
443   // attributes attched to the result type of the signature. The attribute
444   // value should be either in the attribute argument list or the derived
445   // attribute from the input tensors. All the result type
446   // are unranked, and shape inference should be applied afterwards.
447   SmallVector<Type, 4> output_types;
448 
449   // Merge the attributes from the argument list to the derived ones.
450   for (auto& attr : argument_attrs) {
451     derived_attrs.insert({attr.getName(), attr.getValue()});
452   }
453 
454   // Derive the output types by using the attributes attached to the tfr
455   // types.
456   if (failed(DeriveOutputTypes(call_op->getLoc(), func.getFunctionType(),
457                                derived_attrs, &output_types))) {
458     return failure();
459   }
460 
461   // Create the new op and replace the old TFR call op.
462   return CreateAndReplaceOp(rewriter, call_op, output_types, inputs,
463                             argument_attrs, derived_attrs);
464 }
465 
466 // Raise TFR call ops to the TF ops.
467 class RaiseToTFOpsPass
468     : public PassWrapper<RaiseToTFOpsPass, OperationPass<func::FuncOp>> {
469  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseToTFOpsPass)470   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseToTFOpsPass)
471 
472   void getDependentDialects(DialectRegistry& registry) const override {
473     registry.insert<TFRDialect, TF::TensorFlowDialect, scf::SCFDialect,
474                     arith::ArithmeticDialect, func::FuncDialect>();
475   }
476 
RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,bool materialize_derived_attrs)477   explicit RaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,
478                             bool materialize_derived_attrs)
479       : external_tfr_module_(tfr_module),
480         materialize_derived_attrs_(materialize_derived_attrs) {}
481 
getArgument() const482   StringRef getArgument() const final { return "tfr-raise-to-tf"; }
483 
getDescription() const484   StringRef getDescription() const final {
485     return "Raise all the TFR call ops to TF ops.";
486   }
487 
488   void runOnOperation() override;
489 
490  private:
491   llvm::Optional<ModuleOp> external_tfr_module_;
492   const bool materialize_derived_attrs_;
493 };
494 
runOnOperation()495 void RaiseToTFOpsPass::runOnOperation() {
496   func::FuncOp func = getOperation();
497   MLIRContext* ctx = &getContext();
498   SymbolTable table(external_tfr_module_.has_value()
499                         ? *external_tfr_module_
500                         : func->getParentOfType<ModuleOp>());
501 
502   RewritePatternSet patterns(&getContext());
503   patterns.add<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs_);
504 
505   populateCanonicalizationPatterns(func, patterns);
506 
507   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
508 }
509 }  // namespace
510 
511 // Creates an instance of the pass to raise TFR call ops to the TF ops.
CreateRaiseToTFOpsPass(llvm::Optional<ModuleOp> tfr_module,bool materialize_derived_attrs)512 std::unique_ptr<OperationPass<func::FuncOp>> CreateRaiseToTFOpsPass(
513     llvm::Optional<ModuleOp> tfr_module, bool materialize_derived_attrs) {
514   return std::make_unique<RaiseToTFOpsPass>(tfr_module,
515                                             materialize_derived_attrs);
516 }
517 
__anonf1e22a540202null518 static PassRegistration<RaiseToTFOpsPass> pass([] {
519   return CreateRaiseToTFOpsPass();
520 });
521 
522 }  // namespace TFR
523 }  // namespace mlir
524