1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the standard MLIR TensorFlow dialect after control 17 // dependences are raise to the standard form. 18 19 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ 20 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ 21 22 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 23 #include "mlir/IR/Dialect.h" // from @llvm-project 24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" 25 26 namespace mlir { 27 namespace TF { 28 29 class TensorFlowRegistryEffectInterfaceFallback; 30 31 class TensorFlowDialect final : public Dialect { 32 public: 33 explicit TensorFlowDialect(MLIRContext *context); 34 ~TensorFlowDialect() override; 35 getDialectNamespace()36 static StringRef getDialectNamespace() { return "tf"; } 37 38 // Overrides to redirect to tf_type dialect. 39 Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; 40 Type parseType(DialectAsmParser &parser) const override; 41 42 // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a 43 // function references to its gradient function. This attribute in TensorFlow 44 // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the 45 // string description of gradient attribute. GetGradientAttrName()46 static StringRef GetGradientAttrName() { return "tf.gradient"; } 47 48 // This attribute marks if a function is stateful. 49 // Returns the string description of stateful attribute. GetStatefulAttrName()50 static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; } 51 52 // Returns true if the op can be duplicated during transformations. 53 static bool CanDuplicate(Operation *op); 54 55 // Returns true if the op can have side effects. 56 static bool CanHaveSideEffects(Operation *op); 57 58 // Registered hook to materialize a constant operation from a given attribute 59 // value with the desired resultant type. 60 Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, 61 Location loc) override; 62 63 typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction; 64 65 // Register an op registration hook which is invoked during construction. 66 // 67 // A hook may use the public addOperations() method to add additional 68 // operations to the dialect. Hooks will only apply to subsequent 69 // instantations of the Dialect/MLIRContext. 70 static void RegisterAdditionalOperationHook(TypeID uniqueId, 71 AdditionalOpFunction fn); 72 73 // Re-define publicly the protected addOperations() method from the Dialect 74 // class, usually used in a Dialect constructor. This allows hook 75 // functions to register operations on the TensorFlow dialect using the 76 // same interface. 77 template <typename... Args> addOperations()78 void addOperations() { 79 Dialect::addOperations<Args...>(); 80 } 81 82 using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>, 83 SmallVectorImpl<OpFoldResult> &); RegisterConstantFoldHook(ConstantFoldHook fn)84 static void RegisterConstantFoldHook(ConstantFoldHook fn) { 85 constant_fold_hook_ = std::move(fn); 86 } 87 constantFold(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)88 static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands, 89 SmallVectorImpl<OpFoldResult> &results) { 90 if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); 91 return failure(); 92 } 93 94 // Provides a hook for op interface. 95 void *getRegisteredInterfaceForOp(mlir::TypeID interface, 96 mlir::OperationName opName) override; 97 98 private: 99 static ConstantFoldHook constant_fold_hook_; 100 101 // Storage for a custom fallback interface. 102 TensorFlowRegistryEffectInterfaceFallback *fallback_effect_op_interface_; 103 }; 104 105 } // namespace TF 106 } // namespace mlir 107 108 #define TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS(hookFn) \ 109 { \ 110 static bool key; \ 111 ::mlir::TF::TensorFlowDialect::RegisterAdditionalOperationHook( \ 112 ::mlir::TypeID::getFromOpaquePointer(&key), hookFn); \ 113 } 114 115 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ 116