xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file defines the standard MLIR TensorFlow dialect after control
// dependences are raised to the standard form.
18 
19 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
20 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
21 
22 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
23 #include "mlir/IR/Dialect.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
25 
26 namespace mlir {
27 namespace TF {
28 
29 class TensorFlowRegistryEffectInterfaceFallback;
30 
31 class TensorFlowDialect final : public Dialect {
32  public:
33   explicit TensorFlowDialect(MLIRContext *context);
34   ~TensorFlowDialect() override;
35 
getDialectNamespace()36   static StringRef getDialectNamespace() { return "tf"; }
37 
38   // Overrides to redirect to tf_type dialect.
39   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
40   Type parseType(DialectAsmParser &parser) const override;
41 
42   // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
43   // function references to its gradient function. This attribute in TensorFlow
44   // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
45   // string description of gradient attribute.
GetGradientAttrName()46   static StringRef GetGradientAttrName() { return "tf.gradient"; }
47 
48   // This attribute marks if a function is stateful.
49   // Returns the string description of stateful attribute.
GetStatefulAttrName()50   static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; }
51 
52   // Returns true if the op can be duplicated during transformations.
53   static bool CanDuplicate(Operation *op);
54 
55   // Returns true if the op can have side effects.
56   static bool CanHaveSideEffects(Operation *op);
57 
58   // Registered hook to materialize a constant operation from a given attribute
59   // value with the desired resultant type.
60   Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
61                                  Location loc) override;
62 
63   typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
64 
65   // Register an op registration hook which is invoked during construction.
66   //
67   // A hook may use the public addOperations() method to add additional
68   // operations to the dialect. Hooks will only apply to subsequent
69   // instantations of the Dialect/MLIRContext.
70   static void RegisterAdditionalOperationHook(TypeID uniqueId,
71                                               AdditionalOpFunction fn);
72 
73   // Re-define publicly the protected addOperations() method from the Dialect
74   // class, usually used in a Dialect constructor. This allows hook
75   // functions to register operations on the TensorFlow dialect using the
76   // same interface.
77   template <typename... Args>
addOperations()78   void addOperations() {
79     Dialect::addOperations<Args...>();
80   }
81 
82   using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
83                                              SmallVectorImpl<OpFoldResult> &);
RegisterConstantFoldHook(ConstantFoldHook fn)84   static void RegisterConstantFoldHook(ConstantFoldHook fn) {
85     constant_fold_hook_ = std::move(fn);
86   }
87 
constantFold(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)88   static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
89                                     SmallVectorImpl<OpFoldResult> &results) {
90     if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
91     return failure();
92   }
93 
94   // Provides a hook for op interface.
95   void *getRegisteredInterfaceForOp(mlir::TypeID interface,
96                                     mlir::OperationName opName) override;
97 
98  private:
99   static ConstantFoldHook constant_fold_hook_;
100 
101   // Storage for a custom fallback interface.
102   TensorFlowRegistryEffectInterfaceFallback *fallback_effect_op_interface_;
103 };
104 
105 }  // namespace TF
106 }  // namespace mlir
107 
// Registers `hookFn` with TensorFlowDialect::RegisterAdditionalOperationHook so
// that it is invoked when subsequently created TensorFlowDialect instances are
// constructed. Must be expanded at block scope (it introduces a `{ ... }`
// compound statement).
//
// Each expansion declares its own block-scope static; its address is turned
// into the unique TypeID that identifies the hook.
//
// The static uses a long, macro-specific name rather than `key`: `hookFn` is
// evaluated inside the braces below, so a short name would silently shadow any
// same-named identifier referenced by the argument expression (e.g. a lambda
// capture).
#define TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS(hookFn)                \
  {                                                                      \
    static bool tf_dialect_register_additional_operations_key;           \
    ::mlir::TF::TensorFlowDialect::RegisterAdditionalOperationHook(      \
        ::mlir::TypeID::getFromOpaquePointer(                            \
            &tf_dialect_register_additional_operations_key),             \
        (hookFn));                                                       \
  }
114 
115 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
116