xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 #include <numeric>
23 #include <string>
24 #include <tuple>
25 #include <type_traits>
26 #include <utility>
27 
28 #include "absl/strings/str_cat.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/Optional.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/Sequence.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/StringRef.h"
38 #include "llvm/ADT/StringSwitch.h"
39 #include "llvm/ADT/iterator_range.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Support/FormatVariadic.h"
42 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
43 #include "mlir/Dialect/Traits.h"  // from @llvm-project
44 #include "mlir/IR/Attributes.h"  // from @llvm-project
45 #include "mlir/IR/Builders.h"  // from @llvm-project
46 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
49 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
50 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
51 #include "mlir/IR/Location.h"  // from @llvm-project
52 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
53 #include "mlir/IR/Matchers.h"  // from @llvm-project
54 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
55 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
56 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
57 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
58 #include "mlir/IR/Types.h"  // from @llvm-project
59 #include "mlir/IR/Value.h"  // from @llvm-project
60 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
61 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
62 #include "mlir/Parser/Parser.h"  // from @llvm-project
63 #include "mlir/Support/LLVM.h"  // from @llvm-project
64 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
65 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
70 #include "tensorflow/core/common_runtime/inline_function_utils.h"
71 #include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h"
72 #include "tensorflow/core/framework/op.h"
73 #include "tensorflow/core/framework/op_def_builder.h"
74 #include "tensorflow/core/platform/logging.h"
75 #include "tensorflow/core/util/device_name_utils.h"
76 #include "tensorflow/core/util/tensor_format.h"
77 
78 namespace mlir {
79 namespace TF {
80 
81 //===----------------------------------------------------------------------===//
82 // TF Dialect Interfaces
83 //===----------------------------------------------------------------------===//
84 
85 namespace {
86 
87 struct TFConstantFoldInterface : public DialectFoldInterface {
TFConstantFoldInterfacemlir::TF::__anon018ee1e30111::TFConstantFoldInterface88   TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {}
foldmlir::TF::__anon018ee1e30111::TFConstantFoldInterface89   LogicalResult fold(Operation *op, ArrayRef<Attribute> operands,
90                      SmallVectorImpl<OpFoldResult> &results) const final {
91     return TensorFlowDialect::constantFold(op, operands, results);
92   }
93 };
94 
95 // Helper function that implements the multi-device inlining policy behavior
96 // for the inliner hook. In particular, for all function body nodes set unset
97 // placement attributes to match the function call node.
MultiDeviceProcessInlinedCallBlocks(Operation * call,iterator_range<Region::iterator> inlinedBlocks)98 void MultiDeviceProcessInlinedCallBlocks(
99     Operation *call, iterator_range<Region::iterator> inlinedBlocks) {
100   using DeviceNameUtils = tensorflow::DeviceNameUtils;
101 
102   // Duplicate of the logic in MultiDeviceFunctionBodyPlacer::BodyNodeDevice
103   // LINT.IfChange
104   auto device_id = StringAttr::get(call->getContext(), "device");
105   auto caller_device = call->getAttrOfType<StringAttr>(device_id);
106   if (!caller_device) return;
107 
108   DeviceNameUtils::ParsedName caller_parsed_device;
109   if (!DeviceNameUtils::ParseFullName(caller_device.getValue().str(),
110                                       &caller_parsed_device))
111     return;
112 
113   MLIRContext *context = call->getContext();
114   auto node_device = [&](Operation *n) -> StringAttr {
115     auto device = n->getAttrOfType<StringAttr>(device_id);
116     if (!device || device.getValue().empty()) return caller_device;
117 
118     DeviceNameUtils::ParsedName ndef_parsed_device;
119     if (!DeviceNameUtils::ParseFullName(device.getValue().str(),
120                                         &ndef_parsed_device))
121       return device;
122     DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device,
123                                         caller_parsed_device);
124     return StringAttr::get(
125         context, DeviceNameUtils::ParsedNameToString(ndef_parsed_device));
126   };
127   // LINT.ThenChange(../../../../core/common_runtime/inline_function_utils.cc)
128 
129   for (Block &block : inlinedBlocks) {
130     block.walk([&](Operation *op) {
131       if (op->getDialect() == call->getDialect())
132         op->setAttr(device_id, node_device(op));
133     });
134   }
135 }
136 
137 struct TFInlinerInterface : public DialectInlinerInterface {
138   using DialectInlinerInterface::DialectInlinerInterface;
139 
140   //===--------------------------------------------------------------------===//
141   // Analysis Hooks
142   //===--------------------------------------------------------------------===//
143 
144   // Returns if it's legal to inline 'callable' into the 'call', where 'call' is
145   // a TF operation.
isLegalToInlinemlir::TF::__anon018ee1e30111::TFInlinerInterface146   bool isLegalToInline(Operation *call, Operation *callable,
147                        bool wouldBeCloned) const final {
148     // Skip inlining for TPUPartitionedCalls.
149     if (isa<TPUPartitionedCallOp>(call)) return false;
150     // Maintain inlining for  `tf.function`s with jit_compile option.
151     if (callable->hasAttr("tf._XlaMustCompile")) return true;
152     auto noinline_attr_name = absl::StrCat("tf.", tensorflow::kNoInlineAttr);
153     if (auto noinline_attr =
154             callable->getAttrOfType<BoolAttr>(noinline_attr_name))
155       return !noinline_attr.getValue();
156     return true;
157   }
158 
159   // Returns if its legal to inline 'src' region into the 'dest' region
160   // attached to a TF operation.
isLegalToInlinemlir::TF::__anon018ee1e30111::TFInlinerInterface161   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
162                        BlockAndValueMapping &valueMapping) const final {
163     // Allow inlining in regions attached to region based control flow
164     // operations only if the src region is a single block region
165     return isa<IfRegionOp, WhileRegionOp>(dest->getParentOp()) &&
166            llvm::hasSingleElement(*src);
167   }
168 
169   // Returns true if its legal to inline a TF operation `op` into the `dest`
170   // region.
isLegalToInlinemlir::TF::__anon018ee1e30111::TFInlinerInterface171   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
172                        BlockAndValueMapping &) const final {
173     // An op is legal to inline if either of the following conditions is true:
174     // (a) Its legal to duplicate the Op.
175     // (b) The Op is inside a single use function. If that function is inlined,
176     //     post inlining, the function will be dead and eliminated from the IR.
177     //     So there won't be any code duplication.
178     // plus the function caller op can be replaced by inlined ops.
179     return !wouldBeCloned || TensorFlowDialect::CanDuplicate(op);
180   }
181 
182   //===--------------------------------------------------------------------===//
183   // Transformation Hooks
184   //===--------------------------------------------------------------------===//
185 
186   // Attempts to materialize a conversion for a type mismatch between a call
187   // from this dialect, and a callable region. This method should generate an
188   // operation that takes 'input' as the only operand, and produces a single
189   // result of 'resultType'. If a conversion can not be generated, nullptr
190   // should be returned.
materializeCallConversionmlir::TF::__anon018ee1e30111::TFInlinerInterface191   Operation *materializeCallConversion(OpBuilder &builder, Value input,
192                                        Type result_type,
193                                        Location conversion_loc) const final {
194     if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
195       return nullptr;
196     return builder.create<TF::CastOp>(conversion_loc, result_type, input,
197                                       /*truncate=*/builder.getBoolAttr(false));
198   }
199 
processInlinedCallBlocksmlir::TF::__anon018ee1e30111::TFInlinerInterface200   void processInlinedCallBlocks(
201       Operation *call,
202       iterator_range<Region::iterator> inlinedBlocks) const final {
203     bool has_lower_as_multi_device_function_attr = false;
204     if (auto lower = call->getAttrOfType<BoolAttr>(
205             tensorflow::LowerFunctionalOpsConstants::
206                 kLowerAsMultiDeviceFunctionAttr))
207       has_lower_as_multi_device_function_attr = lower.getValue();
208     tensorflow::FunctionCallInlinePolicy policy =
209         tensorflow::GetFunctionCallInlinePolicy(
210             isa<PartitionedCallOp, StatefulPartitionedCallOp>(call),
211             has_lower_as_multi_device_function_attr);
212 
213     if (policy == tensorflow::FunctionCallInlinePolicy::kMultiDevicePlacer)
214       return MultiDeviceProcessInlinedCallBlocks(call, inlinedBlocks);
215   }
216 };
217 }  // end anonymous namespace
218 
219 //===----------------------------------------------------------------------===//
220 // TF Dialect
221 //===----------------------------------------------------------------------===//
222 
223 // Returns true if the op can be duplicated.
CanDuplicate(Operation * op)224 bool TensorFlowDialect::CanDuplicate(Operation *op) {
225   // If the op is marked with the cannot duplicate trait, it cannot be
226   // duplicated.
227   if (op->hasTrait<OpTrait::TF::CannotDuplicate>()) return false;
228 
229   // If the op has no memory side effects, it can be duplicated.
230   if (MemoryEffectOpInterface::hasNoEffect(op)) return true;
231 
232   // If the op is marked stateless using the `is_stateless` attribute, that
233   // attribute determines if the op can be duplicated.
234   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
235     return is_stateless.getValue();
236 
237   // Assume ops can be duplicated if modelled.
238   return op->isRegistered();
239 }
240 
241 // TF dialect fallback for MemoryEffectOpInterface. The filtering for returning
242 // the interface is done in the return below and here it is empty as it is only
243 // returned for known not-stateful and unmodelled ops.
244 struct TensorFlowRegistryEffectInterfaceFallback
245     : public MemoryEffectOpInterface::FallbackModel<
246           TensorFlowRegistryEffectInterfaceFallback> {
classofmlir::TF::TensorFlowRegistryEffectInterfaceFallback247   static bool classof(Operation *op) { return true; }
getEffectsmlir::TF::TensorFlowRegistryEffectInterfaceFallback248   void getEffects(
249       Operation *op,
250       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
251           &effects) const {}
252 };
253 
getRegisteredInterfaceForOp(mlir::TypeID interface,mlir::OperationName opName)254 void *TensorFlowDialect::getRegisteredInterfaceForOp(
255     mlir::TypeID interface, mlir::OperationName opName) {
256   if (interface == TypeID::get<mlir::MemoryEffectOpInterface>()) {
257     // Don't use fallback for modelled ops.
258     if (opName.isRegistered()) return nullptr;
259 
260     // Only use fallback interface for known not-stateful ops.
261     const tensorflow::OpRegistrationData *op_reg_data = nullptr;
262     tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp(
263         opName.stripDialect().str(), &op_reg_data);
264     return (s.ok() && !op_reg_data->op_def.is_stateful())
265                ? fallback_effect_op_interface_
266                : nullptr;
267   }
268 
269   return nullptr;
270 }
271 
272 // Returns true if the op can have side effects.
CanHaveSideEffects(Operation * op)273 bool TensorFlowDialect::CanHaveSideEffects(Operation *op) {
274   // If the op has no memory side effects, it has no side effects
275   if (MemoryEffectOpInterface::hasNoEffect(op)) return false;
276 
277   // If the op is marked stateless using the `is_stateless` attribute, then
278   // it has no side effects.
279   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
280     return !is_stateless.getValue();
281 
282   // Terminators defined in the TF dialect do not have side effects.
283   if (op->hasTrait<OpTrait::IsTerminator>()) return false;
284 
285   // Otherwise assume that the op can have side effects.
286   return true;
287 }
288 
289 // Hook functions which may add additional operations to the dialect.
290 // These are invoked at construction time.
291 static DenseMap<TypeID, TensorFlowDialect::AdditionalOpFunction>
GetAdditionalOperationHooks()292     &GetAdditionalOperationHooks() {
293   static auto *additional_operation_hooks =
294       new DenseMap<TypeID, TensorFlowDialect::AdditionalOpFunction>();
295   return *additional_operation_hooks;
296 }
297 
RegisterAdditionalOperationHook(TypeID id,AdditionalOpFunction fn)298 void TensorFlowDialect::RegisterAdditionalOperationHook(
299     TypeID id, AdditionalOpFunction fn) {
300   GetAdditionalOperationHooks().try_emplace(id, std::move(fn));
301 }
302 
// Out-of-line definition of the static hook member declared in the dialect
// class; it is default-initialized here.
// NOTE(review): presumably consulted by TensorFlowDialect::constantFold, which
// is defined outside this file section - confirm.
TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_;
304 
TensorFlowDialect(MLIRContext * context)305 TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
306     : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>()) {
307   context->getOrLoadDialect<::mlir::tf_type::TFTypeDialect>();
308   addOperations<
309 #define GET_OP_LIST
310 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc"
311       >();
312   addOperations<
313 #define GET_OP_LIST
314 #include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
315       >();
316   addInterfaces<TFInlinerInterface, TFConstantFoldInterface>();
317   fallback_effect_op_interface_ =
318       new TensorFlowRegistryEffectInterfaceFallback();
319 
320   // Support unknown operations because not all TensorFlow operations are
321   // registered.
322   allowUnknownOperations();
323 
324   for (auto &hook : GetAdditionalOperationHooks()) {
325     hook.second(*this);
326   }
327 }
328 
~TensorFlowDialect()329 TensorFlowDialect::~TensorFlowDialect() {
330   delete fallback_effect_op_interface_;
331 }
332 
parseType(DialectAsmParser & parser) const333 Type TensorFlowDialect::parseType(DialectAsmParser &parser) const {
334   StringRef spec = parser.getFullSymbolSpec();
335   llvm::SMLoc loc = parser.getCurrentLocation();
336   parser.emitError(
337       loc, "tf dialect has no types, potentially meant !tf_type." + spec);
338   return nullptr;
339 }
340 
parseAttribute(DialectAsmParser & parser,Type type) const341 Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
342                                             Type type) const {
343   StringRef spec = parser.getFullSymbolSpec();
344   llvm::SMLoc loc = parser.getCurrentLocation();
345   parser.emitError(
346       loc, "tf dialect has no attributes, potentially meant #tf_type." + spec);
347   return nullptr;
348 }
349 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)350 Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder,
351                                                   Attribute value, Type type,
352                                                   Location loc) {
353   return builder.create<ConstOp>(loc, type, value);
354 }
355 
356 }  // namespace TF
357 }  // namespace mlir
358