/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {

//===----------------------------------------------------------------------===//
// TF Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {

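// Dialect fold interface that forwards constant folding of TF operations to
// the TensorFlowDialect::constantFold hook.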
struct TFConstantFoldInterface : public DialectFoldInterface {
  TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
  LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
                     SmallVectorImpl<OpFoldResult> &results) const final {
    return TensorFlowDialect::constantFold(op, operands, results);
  }
};

// Helper function that implements the multi-device inlining policy behavior
// for the inliner hook. In particular, for all function body nodes, it sets
// unset placement attributes to match the function call node.
void MultiDeviceProcessInlinedCallBlocks(
    Operation *call, iterator_range<Region::iterator> inlinedBlocks) {
  using DeviceNameUtils = tensorflow::DeviceNameUtils;

  // Duplicate of the logic in MultiDeviceFunctionBodyPlacer::BodyNodeDevice
  // LINT.IfChange
  auto device_id = StringAttr::get(call->getContext(), "device");
  auto caller_device = call->getAttrOfType<StringAttr>(device_id);
  if (!caller_device) return;

  DeviceNameUtils::ParsedName caller_parsed_device;
  if (!DeviceNameUtils::ParseFullName(caller_device.getValue().str(),
                                      &caller_parsed_device))
    return;

  MLIRContext *context = call->getContext();
  auto node_device = [&](Operation *n) -> StringAttr {
    auto device = n->getAttrOfType<StringAttr>(device_id);
    if (!device || device.getValue().empty()) return caller_device;

    DeviceNameUtils::ParsedName ndef_parsed_device;
    if (!DeviceNameUtils::ParseFullName(device.getValue().str(),
                                        &ndef_parsed_device))
      return device;
    DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device,
                                        caller_parsed_device);
    return StringAttr::get(
        context, DeviceNameUtils::ParsedNameToString(ndef_parsed_device));
  };
  // LINT.ThenChange(../../../../core/common_runtime/inline_function_utils.cc)

  for (Block &block : inlinedBlocks) {
    block.walk([&](Operation *op) {
      if (op->getDialect() == call->getDialect())
        op->setAttr(device_id, node_device(op));
    });
  }
}

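// Dialect inliner interface that controls whether TF call operations and
// regions may be inlined, and how inlined call blocks are post-processed.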
struct TFInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  // Returns whether it is legal to inline 'callable' into 'call', where 'call'
  // is a TF operation.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    // Skip inlining for TPUPartitionedCalls.
    if (isa<TPUPartitionedCallOp>(call)) return false;
    // Maintain inlining for `tf.function`s with the jit_compile option.
    if (callable->hasAttr("tf._XlaMustCompile")) return true;
    auto noinline_attr_name = absl::StrCat("tf.", tensorflow::kNoInlineAttr);
    if (auto noinline_attr =
            callable->getAttrOfType<BoolAttr>(noinline_attr_name))
      return !noinline_attr.getValue();
    return true;
  }

  // Returns whether it is legal to inline the 'src' region into the 'dest'
  // region attached to a TF operation.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // Allow inlining into regions attached to region-based control flow
    // operations only if the src region is a single-block region.
    return isa<IfRegionOp, WhileRegionOp>(dest->getParentOp()) &&
           llvm::hasSingleElement(*src);
  }

  // Returns true if it is legal to inline a TF operation `op` into the `dest`
  // region.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    // An op is legal to inline if either of the following conditions is true:
    // (a) It is legal to duplicate the op.
    // (b) The op is inside a single-use function. If that function is inlined,
    //     it becomes dead afterwards and is eliminated from the IR, so there
    //     is no code duplication; in addition, the function call op can be
    //     replaced by the inlined ops.
    return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  // Attempts to materialize a conversion for a type mismatch between a call
  // from this dialect and a callable region. This method should generate an
  // operation that takes 'input' as the only operand and produces a single
  // result of 'resultType'. If a conversion cannot be generated, nullptr
  // should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type result_type,
                                       Location conversion_loc) const final {
    if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
      return nullptr;
    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
                                      /*truncate=*/builder.getBoolAttr(false));
  }

  void processInlinedCallBlocks(
      Operation *call,
      iterator_range<Region::iterator> inlinedBlocks) const final {
    bool has_lower_as_multi_device_function_attr = false;
    if (auto lower = call->getAttrOfType<BoolAttr>(
            tensorflow::LowerFunctionalOpsConstants::
                kLowerAsMultiDeviceFunctionAttr))
      has_lower_as_multi_device_function_attr = lower.getValue();
    tensorflow::FunctionCallInlinePolicy policy =
        tensorflow::GetFunctionCallInlinePolicy(
            isa<PartitionedCallOp, StatefulPartitionedCallOp>(call),
            has_lower_as_multi_device_function_attr);

    if (policy == tensorflow::FunctionCallInlinePolicy::kMultiDevicePlacer)
      return MultiDeviceProcessInlinedCallBlocks(call, inlinedBlocks);
  }
};
}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// TF Dialect
//===----------------------------------------------------------------------===//

// Returns true if the op can be duplicated.
bool TensorFlowDialect::CanDuplicate(Operation *op) {
  // If the op is marked with the CannotDuplicate trait, it cannot be
  // duplicated.
  if (op->hasTrait<OpTrait::TF::CannotDuplicate>()) return false;

  // If the op has no memory side effects, it can be duplicated.
  if (MemoryEffectOpInterface::hasNoEffect(op)) return true;

  // If the op is marked stateless using the `is_stateless` attribute, that
  // attribute determines whether the op can be duplicated.
  if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
    return is_stateless.getValue();

  // Otherwise, assume the op can be duplicated only if it is modelled
  // (registered).
  return op->isRegistered();
}

// TF dialect fallback for MemoryEffectOpInterface. The filtering that decides
// whether to return this interface happens in getRegisteredInterfaceForOp
// below; getEffects is empty here because the fallback is only returned for
// unmodelled ops that are known to be not stateful.
struct TensorFlowRegistryEffectInterfaceFallback
    : public MemoryEffectOpInterface::FallbackModel<
          TensorFlowRegistryEffectInterfaceFallback> {
  static bool classof(Operation *op) { return true; }
  void getEffects(
      Operation *op,
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
          &effects) const {}
};

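// Returns the fallback MemoryEffectOpInterface model for unregistered ops that
// the TensorFlow op registry reports as not stateful; returns nullptr for all
// other interface/op combinations.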
void *TensorFlowDialect::getRegisteredInterfaceForOp(
    mlir::TypeID interface, mlir::OperationName opName) {
  if (interface == TypeID::get<mlir::MemoryEffectOpInterface>()) {
    // Don't use the fallback for modelled ops.
    if (opName.isRegistered()) return nullptr;

    // Only use the fallback interface for known not-stateful ops.
    const tensorflow::OpRegistrationData *op_reg_data = nullptr;
    tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp(
        opName.stripDialect().str(), &op_reg_data);
    return (s.ok() && !op_reg_data->op_def.is_stateful())
               ? fallback_effect_op_interface_
               : nullptr;
  }

  return nullptr;
}

// Returns true if the op can have side effects.
bool TensorFlowDialect::CanHaveSideEffects(Operation *op) {
  // If the op has no memory side effects, it has no side effects.
  if (MemoryEffectOpInterface::hasNoEffect(op)) return false;

  // If the op is marked stateless using the `is_stateless` attribute, then
  // it has no side effects.
  if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
    return !is_stateless.getValue();

  // Terminators defined in the TF dialect do not have side effects.
  if (op->hasTrait<OpTrait::IsTerminator>()) return false;

  // Otherwise assume that the op can have side effects.
  return true;
}

// Hook functions which may add additional operations to the dialect.
// These are invoked at construction time.
static DenseMap<TypeID, TensorFlowDialect::AdditionalOpFunction>
    &GetAdditionalOperationHooks() {
  static auto *additional_operation_hooks =
      new DenseMap<TypeID, TensorFlowDialect::AdditionalOpFunction>();
  return *additional_operation_hooks;
}

void TensorFlowDialect::RegisterAdditionalOperationHook(
    TypeID id, AdditionalOpFunction fn) {
  GetAdditionalOperationHooks().try_emplace(id, std::move(fn));
}

TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;

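// Loads the tf_type dialect, registers all generated TF and TFRT operations
// and the dialect interfaces, installs the registry-based effect interface
// fallback, and runs any registered additional-operation hooks.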
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
  context->getOrLoadDialect<::mlir::tf_type::TFTypeDialect>();
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
      >();
  addInterfaces<TFInlinerInterface, TFConstantFoldInterface>();
  fallback_effect_op_interface_ =
      new TensorFlowRegistryEffectInterfaceFallback();

  // Support unknown operations because not all TensorFlow operations are
  // registered.
  allowUnknownOperations();

  for (auto &hook : GetAdditionalOperationHooks()) {
    hook.second(*this);
  }
}

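// Frees the fallback effect interface model allocated in the constructor.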
TensorFlowDialect::~TensorFlowDialect() {
  delete fallback_effect_op_interface_;
}

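// The tf dialect defines no types of its own; emit an error that points users
// at the corresponding !tf_type type.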
Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
  StringRef spec = parser.getFullSymbolSpec();
  llvm::SMLoc loc = parser.getCurrentLocation();
  parser.emitError(
      loc, "tf dialect has no types, potentially meant !tf_type." + spec);
  return nullptr;
}

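// The tf dialect defines no attributes of its own; emit an error that points
// users at the corresponding #tf_type attribute.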
Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
                                            Type type) const {
  StringRef spec = parser.getFullSymbolSpec();
  llvm::SMLoc loc = parser.getCurrentLocation();
  parser.emitError(
      loc, "tf dialect has no attributes, potentially meant #tf_type." + spec);
  return nullptr;
}

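// Materializes constants produced by folding as tf.Const operations.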
Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
                                                  Attribute value, Type type,
                                                  Location loc) {
  return builder.create<ConstOp>(loc, type, value);
}

}  // namespace TF
}  // namespace mlir