1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstddef>
17 #include <memory>
18
19 #include "absl/strings/str_cat.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/iterator_range.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/Location.h" // from @llvm-project
28 #include "mlir/IR/MLIRContext.h" // from @llvm-project
29 #include "mlir/IR/Operation.h" // from @llvm-project
30 #include "mlir/IR/OperationSupport.h" // from @llvm-project
31 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
32 #include "mlir/Pass/PassManager.h" // from @llvm-project
33 #include "mlir/Support/LLVM.h" // from @llvm-project
34 #include "tensorflow/c/c_api.h"
35 #include "tensorflow/c/eager/abstract_context.h"
36 #include "tensorflow/c/eager/abstract_operation.h"
37 #include "tensorflow/c/eager/abstract_tensor_handle.h"
38 #include "tensorflow/c/eager/c_api.h"
39 #include "tensorflow/c/eager/c_api_internal.h"
40 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
41 #include "tensorflow/c/tensor_interface.h"
42 #include "tensorflow/c/tf_status.h"
43 #include "tensorflow/c/tf_status_helper.h"
44 #include "tensorflow/c/tf_status_internal.h"
45 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
50 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
52 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
53 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
55 #include "tensorflow/core/framework/node_def_util.h"
56 #include "tensorflow/core/framework/tensor_shape.h"
57 #include "tensorflow/core/framework/types.pb.h"
58 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
59 #include "tensorflow/core/platform/errors.h"
60
61 namespace mlir {
62 namespace TF {
63 using tensorflow::AbstractFunction;
64 using tensorflow::AbstractOperation;
65 using tensorflow::AbstractTensorHandle;
66 using tensorflow::AbstractTensorInterface;
67 using tensorflow::dyn_cast;
68 using tensorflow::OutputList;
69 using tensorflow::string;
70 using tensorflow::errors::FailedPrecondition;
71 using tensorflow::errors::InvalidArgument;
72 using tensorflow::errors::Unimplemented;
73 using tensorflow::tracing::TracingContext;
74 using tensorflow::tracing::TracingOperation;
75 using tensorflow::tracing::TracingTensorHandle;
76
77 namespace {
78
RegisterDialects(mlir::MLIRContext & ctx)79 void RegisterDialects(mlir::MLIRContext& ctx) {
80 mlir::DialectRegistry registry;
81 mlir::RegisterAllTensorFlowDialects(registry);
82 ctx.appendDialectRegistry(registry);
83 ctx.loadAllAvailableDialects();
84 }
85
ConvertDataTypeToTensor(tensorflow::DataType dtype,Builder builder,Type * type)86 Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
87 Type* type) {
88 Status s = tensorflow::ConvertDataType(dtype, builder, type);
89 if (s.ok()) *type = UnrankedTensorType::get(*type);
90 return s;
91 }
92
93 class MlirTensor : public TracingTensorHandle {
94 public:
MlirTensor(Value value)95 explicit MlirTensor(Value value)
96 : TracingTensorHandle(kMlir), value_(value) {}
97
DataType() const98 tensorflow::DataType DataType() const override {
99 tensorflow::DataType type;
100 Status s = ConvertToDataType(value_.getType(), &type);
101 if (!s.ok()) {
102 return tensorflow::DT_INVALID;
103 }
104 return type;
105 }
106
Shape(tensorflow::PartialTensorShape * shape) const107 tensorflow::Status Shape(
108 tensorflow::PartialTensorShape* shape) const override {
109 // TODO(b/173074167): Implement this and enable tests in
110 // unified_api_test.cc.
111 return Unimplemented("MlirTensor::Shape is not implemented yet.");
112 }
113
getValue()114 Value getValue() { return value_; }
getElementType()115 Type getElementType() {
116 return value_.getType().cast<ShapedType>().getElementType();
117 }
118
119 // For LLVM style RTTI.
classof(const AbstractTensorHandle * ptr)120 static bool classof(const AbstractTensorHandle* ptr) {
121 return ptr->getKind() == kMlir;
122 }
123
124 private:
125 Value value_;
126 };
127
128 class MlirFunctionContext;
129
130 class MlirAbstractOp : public TracingOperation {
131 public:
MlirAbstractOp(MLIRContext * context,MlirFunctionContext * function_context)132 explicit MlirAbstractOp(MLIRContext* context,
133 MlirFunctionContext* function_context)
134 : TracingOperation(kMlir),
135 context_(context),
136 function_context_(function_context) {}
137
Release()138 void Release() override { delete this; }
139
140 Status Reset(const char* op, const char* raw_device_name) override;
141
142 const string& Name() const override;
143
144 const string& DeviceName() const override;
145
146 Status SetDeviceName(const char* name) override;
147
148 Status AddInput(AbstractTensorHandle* input) override;
149 Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
150 Status Execute(absl::Span<AbstractTensorHandle*> retvals,
151 int* num_retvals) override;
152
153 Status SetAttrString(const char* attr_name, const char* data,
154 size_t length) override;
155 Status SetAttrInt(const char* attr_name, int64_t value) override;
156 Status SetAttrFloat(const char* attr_name, float value) override;
157 Status SetAttrBool(const char* attr_name, bool value) override;
158 Status SetAttrType(const char* attr_name,
159 tensorflow::DataType dtype) override;
160 Status SetAttrShape(const char* attr_name, const int64_t* dims,
161 const int num_dims) override;
162 Status SetAttrFunction(const char* attr_name,
163 const AbstractOperation* value) override;
164 Status SetAttrFunctionName(const char* attr_name, const char* value,
165 size_t length) override;
166 Status SetAttrTensor(const char* attr_name,
167 AbstractTensorInterface* tensor) override;
168 Status SetAttrStringList(const char* attr_name, const void* const* values,
169 const size_t* lengths, int num_values) override;
170 Status SetAttrFloatList(const char* attr_name, const float* values,
171 int num_values) override;
172 Status SetAttrIntList(const char* attr_name, const int64_t* values,
173 int num_values) override;
174 Status SetAttrTypeList(const char* attr_name,
175 const tensorflow::DataType* values,
176 int num_values) override;
177 Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
178 int num_values) override;
179 Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
180 const int* num_dims, int num_values) override;
181 Status SetAttrFunctionList(
182 const char* attr_name,
183 absl::Span<const AbstractOperation*> values) override;
184
185 Status SetOpName(const char* const op_name) override;
186
GetContext()187 MLIRContext* GetContext() { return context_; }
188
189 Status AddRef(Type type, Type* output_type);
190
191 Status Create(ArrayRef<Value> operands, OperationState**);
192
193 // For LLVM style RTTI.
classof(const AbstractOperation * ptr)194 static bool classof(const AbstractOperation* ptr) {
195 return ptr->getKind() == kMlir;
196 }
197
198 private:
199 // Return true is there are still unfilled ODS slots for adding more inputs.
200 bool IsNextODSArgAvailable();
201
202 MLIRContext* context_;
203 MlirFunctionContext* function_context_;
204 SmallVector<Value, 8> operands_;
205 llvm::StringMap<Attribute> attrs_;
206 std::unique_ptr<OperationState> state_;
207 // This is the index of the next ODS operand that will be added with AddInput
208 // or AddInput;
209 int current_ods_input_ = 0;
210 const tensorflow::OpDef* op_def_ = nullptr;
211 const char* op_name_ = nullptr;
212 string tf_op_type_;
213 // TODO(srbs): Use this.
214 string device_name_;
215 };
216
217 // MlirFunction is a thin wrapper over a FuncOp.
218 class MlirFunction : public AbstractFunction {
219 public:
MlirFunction(std::unique_ptr<MLIRContext> context,OwningOpRef<mlir::ModuleOp> module,func::FuncOp func)220 explicit MlirFunction(std::unique_ptr<MLIRContext> context,
221 OwningOpRef<mlir::ModuleOp> module, func::FuncOp func)
222 : AbstractFunction(kMlir),
223 context_(std::move(context)),
224 module_(std::move(module)),
225 func_(func) {}
226
227 Status GetFunctionDef(tensorflow::FunctionDef** f) override;
228
229 // For LLVM style RTTI.
classof(const AbstractFunction * ptr)230 static bool classof(const AbstractFunction* ptr) {
231 return ptr->getKind() == kMlir;
232 }
233
234 private:
235 std::unique_ptr<MLIRContext> context_;
236 OwningOpRef<mlir::ModuleOp> module_;
237 func::FuncOp func_;
238 std::unique_ptr<tensorflow::FunctionDef> fdef_;
239 };
240
241 class MlirFunctionContext : public TracingContext {
242 public:
MlirFunctionContext(const char * name)243 explicit MlirFunctionContext(const char* name)
244 : TracingContext(kMlir),
245 context_(std::make_unique<MLIRContext>()),
246 builder_(context_.get()) {
247 RegisterDialects(*context_);
248 // TODO(aminim) figure out the location story here
249 module_ = ModuleOp::create(builder_.getUnknownLoc());
250 func_ =
251 func::FuncOp::create(builder_.getUnknownLoc(), name,
252 builder_.getFunctionType(llvm::None, llvm::None));
253 module_->push_back(func_);
254 builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
255 }
256
Release()257 void Release() override { delete this; }
258
CreateOperation()259 AbstractOperation* CreateOperation() override {
260 return new MlirAbstractOp(context_.get(), this);
261 }
262 Status AddParameter(tensorflow::DataType dtype,
263 const tensorflow::PartialTensorShape& shape,
264 TracingTensorHandle** handle) override;
265
266 Status Finalize(OutputList* outputs, AbstractFunction** f) override;
267
RegisterFunction(AbstractFunction * func)268 Status RegisterFunction(AbstractFunction* func) override {
269 return Unimplemented(
270 "Registering graph functions has not been implemented yet.");
271 }
272
RemoveFunction(const string & func)273 Status RemoveFunction(const string& func) override {
274 return Unimplemented(
275 "MlirFunctionContext::RemoveFunction has not been implemented yet.");
276 }
277
278 Operation* CreateOperationFromState(const OperationState& state);
279
280 private:
281 std::unique_ptr<MLIRContext> context_;
282 OpBuilder builder_;
283 func::FuncOp func_;
284 OwningOpRef<mlir::ModuleOp> module_;
285 };
286
Reset(const char * op,const char * device_name)287 Status MlirAbstractOp::Reset(const char* op, const char* device_name) {
288 if (state_) {
289 return FailedPrecondition("Reset called on already built op.");
290 }
291 TF_RETURN_IF_ERROR(
292 tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_));
293 assert(op_def_);
294
295 tf_op_type_ = op;
296 std::string name = "tf.";
297 name += op;
298 // TODO(aminim) figure out the location story here
299 state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
300 return ::tensorflow::OkStatus();
301 }
302
SetAttrType(const char * attr_name,tensorflow::DataType dtype)303 Status MlirAbstractOp::SetAttrType(const char* attr_name,
304 tensorflow::DataType dtype) {
305 if (!state_)
306 return FailedPrecondition(
307 "op_type must be specified before specifying attrs.");
308 Type mlir_type;
309 Builder builder(context_);
310 TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type));
311 attrs_[attr_name] = TypeAttr::get(mlir_type);
312 return ::tensorflow::OkStatus();
313 }
314
SetOpName(const char * const op_name)315 Status MlirAbstractOp::SetOpName(const char* const op_name) {
316 // TODO(aminim): should we use a location?
317 if (op_name_) {
318 return FailedPrecondition("SetOpName called on already built op.");
319 }
320 op_name_ = op_name;
321 return ::tensorflow::OkStatus();
322 }
323
AddRef(Type type,Type * output_type)324 Status MlirAbstractOp::AddRef(Type type, Type* output_type) {
325 Type elt_type = getElementTypeOrSelf(type);
326 if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
327 return InvalidArgument("Requested reference to a reference type");
328 }
329 elt_type = TensorFlowRefType::get(elt_type);
330 if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
331 *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type);
332 }
333 *output_type = UnrankedTensorType::get(elt_type);
334 return ::tensorflow::OkStatus();
335 }
336
Create(ArrayRef<Value> operands,OperationState ** state)337 Status MlirAbstractOp::Create(ArrayRef<Value> operands,
338 OperationState** state) {
339 state_->operands = llvm::to_vector<4>(operands);
340 Builder builder(context_);
341
342 if (current_ods_input_ != op_def_->input_arg_size())
343 return InvalidArgument(absl::StrCat("Mismatch in operands number: got ",
344 current_ods_input_, " expected ",
345 op_def_->input_arg_size(), " ; for op ",
346 state_->name.getStringRef().str()));
347
348 // Process results according to the op_def and infer types for derived
349 // attributes.
350 for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) {
351 int original_size = state_->types.size();
352 if (!output_arg.number_attr().empty()) {
353 // Same type repeated "repeats" times.
354 Attribute repeats_attr = attrs_[output_arg.number_attr()];
355 if (!repeats_attr)
356 return InvalidArgument("Missing attribute '", output_arg.number_attr(),
357 "' required for output list '",
358 output_arg.name(), "'");
359 if (!repeats_attr.isa<IntegerAttr>())
360 return InvalidArgument("Attribute '", output_arg.number_attr(),
361 "' required for output list '",
362 output_arg.name(), "' isn't an integer");
363 int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
364
365 if (!output_arg.type_attr().empty()) {
366 // Same type repeated "repeats" times.
367 Attribute attr = attrs_[output_arg.type_attr()];
368 if (!attr)
369 return InvalidArgument("Missing attribute '", output_arg.type_attr(),
370 "' required for output '", output_arg.name(),
371 "'");
372 TypedAttr type_attr = attr.dyn_cast<TypedAttr>();
373 if (!type_attr)
374 return InvalidArgument("Attribute '", output_arg.type_attr(),
375 "' required for output '", output_arg.name(),
376 "' isn't a type attribute");
377 for (int i = 0; i < repeats; ++i)
378 state_->types.push_back(UnrankedTensorType::get(type_attr.getType()));
379 } else if (output_arg.type() != tensorflow::DT_INVALID) {
380 for (int i = 0; i < repeats; ++i) {
381 Type type;
382 TF_RETURN_IF_ERROR(
383 ConvertDataType(output_arg.type(), builder, &type));
384 state_->types.push_back(type);
385 }
386 } else {
387 return InvalidArgument("Missing type or type_attr field in ",
388 output_arg.ShortDebugString());
389 }
390 } else if (!output_arg.type_attr().empty()) {
391 Attribute attr = attrs_[output_arg.type_attr()];
392 if (!attr)
393 return InvalidArgument("Missing attribute '", output_arg.type_attr(),
394 "' required for output '", output_arg.name(),
395 "'");
396 TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
397 if (!type_attr)
398 return InvalidArgument("Attribute '", output_arg.type_attr(),
399 "' required for output '", output_arg.name(),
400 "' isn't a type attribute");
401 state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
402 } else if (!output_arg.type_list_attr().empty()) {
403 // This is pointing to an attribute which is an array of types.
404 Attribute attr = attrs_[output_arg.type_list_attr()];
405 if (!attr)
406 return InvalidArgument(
407 "Missing attribute '", output_arg.type_list_attr(),
408 "' required for output '", output_arg.name(), "'");
409 ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
410 if (!array_attr)
411 return InvalidArgument("Attribute '", output_arg.type_list_attr(),
412 "' required for output '", output_arg.name(),
413 "' isn't an array attribute");
414 for (Attribute attr : array_attr) {
415 TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
416 if (!type_attr)
417 return InvalidArgument("Array Attribute '",
418 output_arg.type_list_attr(),
419 "' required for output '", output_arg.name(),
420 "' has a non-Type element");
421 state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
422 }
423 } else if (output_arg.type() != tensorflow::DT_INVALID) {
424 Type type;
425 Builder builder(context_);
426 TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type));
427 state_->types.push_back(type);
428 } else {
429 return InvalidArgument("No type fields in ",
430 output_arg.ShortDebugString());
431 }
432 if (output_arg.is_ref()) {
433 // For all types that were added by this function call, make them refs.
434 for (Type& type : llvm::make_range(&state_->types[original_size],
435 state_->types.end())) {
436 Type output_type;
437 TF_RETURN_IF_ERROR(AddRef(type, &output_type));
438 type = output_type;
439 }
440 }
441 }
442 for (auto& it : attrs_) state_->addAttribute(it.first(), it.second);
443 *state = state_.get();
444 return ::tensorflow::OkStatus();
445 }
446
Name() const447 const string& MlirAbstractOp::Name() const { return tf_op_type_; }
448
DeviceName() const449 const string& MlirAbstractOp::DeviceName() const { return device_name_; }
450
SetDeviceName(const char * name)451 Status MlirAbstractOp::SetDeviceName(const char* name) {
452 device_name_ = name;
453 return ::tensorflow::OkStatus();
454 }
455
SetAttrString(const char * attr_name,const char * data,size_t length)456 Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
457 size_t length) {
458 return Unimplemented("SetAttrString has not been implemented yet.");
459 }
SetAttrInt(const char * attr_name,int64_t value)460 Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) {
461 return Unimplemented("SetAttrInt has not been implemented yet.");
462 }
SetAttrFloat(const char * attr_name,float value)463 Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
464 return Unimplemented("SetAttrFloat has not been implemented yet.");
465 }
SetAttrBool(const char * attr_name,bool value)466 Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
467 attrs_[attr_name] = BoolAttr::get(context_, value);
468 return ::tensorflow::OkStatus();
469 }
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)470 Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
471 const int num_dims) {
472 return Unimplemented("SetAttrShape has not been implemented yet.");
473 }
SetAttrFunction(const char * attr_name,const AbstractOperation * value)474 Status MlirAbstractOp::SetAttrFunction(const char* attr_name,
475 const AbstractOperation* value) {
476 return Unimplemented("SetAttrFunction has not been implemented yet.");
477 }
SetAttrFunctionName(const char * attr_name,const char * value,size_t length)478 Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name,
479 const char* value, size_t length) {
480 return Unimplemented("SetAttrFunctionName has not been implemented yet.");
481 }
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)482 Status MlirAbstractOp::SetAttrTensor(const char* attr_name,
483 AbstractTensorInterface* tensor) {
484 return Unimplemented("SetAttrTensor has not been implemented yet.");
485 }
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)486 Status MlirAbstractOp::SetAttrStringList(const char* attr_name,
487 const void* const* values,
488 const size_t* lengths,
489 int num_values) {
490 return Unimplemented("SetAttrStringList has not been implemented yet.");
491 }
SetAttrFloatList(const char * attr_name,const float * values,int num_values)492 Status MlirAbstractOp::SetAttrFloatList(const char* attr_name,
493 const float* values, int num_values) {
494 return Unimplemented("SetAttrFloatList has not been implemented yet.");
495 }
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)496 Status MlirAbstractOp::SetAttrIntList(const char* attr_name,
497 const int64_t* values, int num_values) {
498 return Unimplemented("SetAttrIntList has not been implemented yet.");
499 }
SetAttrTypeList(const char * attr_name,const tensorflow::DataType * values,int num_values)500 Status MlirAbstractOp::SetAttrTypeList(const char* attr_name,
501 const tensorflow::DataType* values,
502 int num_values) {
503 return Unimplemented("SetAttrTypeList has not been implemented yet.");
504 }
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)505 Status MlirAbstractOp::SetAttrBoolList(const char* attr_name,
506 const unsigned char* values,
507 int num_values) {
508 return Unimplemented("SetAttrBoolList has not been implemented yet.");
509 }
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)510 Status MlirAbstractOp::SetAttrShapeList(const char* attr_name,
511 const int64_t** dims,
512 const int* num_dims, int num_values) {
513 return Unimplemented("SetAttrShapeList has not been implemented yet.");
514 }
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)515 Status MlirAbstractOp::SetAttrFunctionList(
516 const char* attr_name, absl::Span<const AbstractOperation*> values) {
517 return Unimplemented("SetAttrFunctionList has not been implemented yet.");
518 }
519
GetFunctionDef(tensorflow::FunctionDef ** f)520 Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
521 if (fdef_) {
522 *f = fdef_.get();
523 return ::tensorflow::OkStatus();
524 }
525 PassManager pm(func_.getContext());
526 ::tensorflow::applyTensorflowAndCLOptions(pm);
527 pm.addNestedPass<func::FuncOp>(
528 CreateFunctionalToExecutorDialectConversionPass());
529 pm.addPass(CreateBreakUpIslandsPass());
530
531 // In case of failure, the `diag_handler` converts MLIR errors emitted to
532 // the MLIRContext into a tensorflow::Status.
533 StatusScopedDiagnosticHandler diag_handler(func_.getContext());
534 LogicalResult result = pm.run(func_->getParentOfType<ModuleOp>());
535 (void)result;
536 TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());
537
538 tensorflow::GraphExportConfig configs;
539 fdef_.reset(new tensorflow::FunctionDef());
540 TF_RETURN_IF_ERROR(
541 ConvertMlirFunctionToFunctionLibraryDef(func_, configs, fdef_.get()));
542 *f = fdef_.get();
543 return ::tensorflow::OkStatus();
544 }
545
Execute(absl::Span<AbstractTensorHandle * > retvals,int * num_retvals)546 Status MlirAbstractOp::Execute(absl::Span<AbstractTensorHandle*> retvals,
547 int* num_retvals) {
548 OperationState* state;
549 TF_RETURN_IF_ERROR(Create(operands_, &state));
550 Operation* op = function_context_->CreateOperationFromState(*state);
551 *num_retvals = op->getNumResults();
552 for (int i = 0; i < *num_retvals; i++)
553 retvals[i] = new MlirTensor(op->getResult(i));
554 return ::tensorflow::OkStatus();
555 }
556
CreateOperationFromState(const OperationState & state)557 Operation* MlirFunctionContext::CreateOperationFromState(
558 const OperationState& state) {
559 return builder_.create(state);
560 }
561
AddParameter(tensorflow::DataType dtype,const tensorflow::PartialTensorShape & shape,TracingTensorHandle ** handle)562 Status MlirFunctionContext::AddParameter(
563 tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
564 TracingTensorHandle** handle) {
565 // TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
566 // resolved.
567 Type type;
568 TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
569 *handle =
570 new MlirTensor(func_.getBody().front().addArgument(type, func_.getLoc()));
571 return ::tensorflow::OkStatus();
572 }
573
AddInput(AbstractTensorHandle * input)574 Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) {
575 if (current_ods_input_ >= op_def_->input_arg_size())
576 return InvalidArgument(
577 absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
578 op_def_->input_arg_size(), " allowed input_args ; for op ",
579 state_->name.getStringRef().str()));
580
581 auto* operand = dyn_cast<MlirTensor>(input);
582 if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
583 operands_.push_back(operand->getValue());
584
585 // Get the next ArgDef and use it to infer the derived attributes associated
586 // to this input.
587 const tensorflow::OpDef::ArgDef& arg_def =
588 op_def_->input_arg(current_ods_input_++);
589 Type expected_type;
590 if (arg_def.type() != tensorflow::DT_INVALID) {
591 Builder builder(context_);
592 TF_RETURN_IF_ERROR(
593 tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type));
594 if (arg_def.is_ref()) {
595 Type output_type;
596 TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type));
597 expected_type = output_type;
598 }
599 } else {
600 expected_type = cast<MlirTensor>(input)->getElementType();
601 }
602 if (!arg_def.type_attr().empty())
603 attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
604
605 return ::tensorflow::OkStatus();
606 }
607
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)608 Status MlirAbstractOp::AddInputList(
609 absl::Span<AbstractTensorHandle* const> inputs) {
610 if (current_ods_input_ >= op_def_->input_arg_size())
611 return InvalidArgument(
612 absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
613 op_def_->input_arg_size(), " allowed input_args"));
614
615 for (AbstractTensorHandle* input : inputs) {
616 auto* operand = dyn_cast<MlirTensor>(input);
617 if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
618 operands_.push_back(operand->getValue());
619 }
620
621 // Get the next ArgDef and use it to infer the derived attributes associated
622 // to this input.
623 const tensorflow::OpDef::ArgDef& arg_def =
624 op_def_->input_arg(current_ods_input_++);
625 if (!arg_def.number_attr().empty()) {
626 Builder builder(context_);
627 attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size());
628 // TODO(aminim): handle ref variable.
629 if (arg_def.type() != tensorflow::DT_INVALID) {
630 // TODO(aminim): check type wrt input
631 Type arg_def_type;
632 TF_RETURN_IF_ERROR(
633 ConvertDataType(arg_def.type(), builder, &arg_def_type));
634 // Ensure each of the type in the list matches the op def type.
635 // TODO(aminim): can we improve the error message with the actual types?
636 for (AbstractTensorHandle* input : inputs)
637 if (arg_def_type != cast<MlirTensor>(input)->getElementType())
638 return InvalidArgument(
639 "Invalid input list: type mismatch the op def expectation");
640 } else if (!inputs.empty()) {
641 if (arg_def.type_attr().empty())
642 return FailedPrecondition(
643 "Invalid opdef type constraint: either type or type_attr required");
644
645 attrs_[arg_def.type_attr()] =
646 TypeAttr::get(cast<MlirTensor>(inputs.front())->getElementType());
647 }
648 } else if (!arg_def.type_list_attr().empty()) {
649 // TODO(aminim): handle ref variable.
650 SmallVector<Attribute, 8> types;
651 types.reserve(inputs.size());
652 for (AbstractTensorHandle* input : inputs)
653 types.push_back(TypeAttr::get(cast<MlirTensor>(input)->getElementType()));
654 attrs_[arg_def.type_list_attr()] = ArrayAttr::get(GetContext(), types);
655 }
656 return ::tensorflow::OkStatus();
657 }
658
Finalize(OutputList * outputs,AbstractFunction ** f)659 Status MlirFunctionContext::Finalize(OutputList* outputs,
660 AbstractFunction** f) {
661 Block& body = func_.getBody().front();
662 SmallVector<Value, 8> ret_operands;
663 for (auto* output : outputs->outputs) {
664 auto* operand = dyn_cast<MlirTensor>(output);
665 if (!operand)
666 return InvalidArgument("Capturing eager tensors is not supported yet.");
667 if (operand->getValue().getContext() != context_.get())
668 return InvalidArgument(
669 "Capturing tensors from other context is not supported.");
670 ret_operands.push_back(operand->getValue());
671 }
672 builder_.create<func::ReturnOp>(func_.getLoc(), ret_operands);
673
674 auto arg_types = body.getArgumentTypes();
675 auto result_types = body.getTerminator()->getOperandTypes();
676 func_.setType(FunctionType::get(func_.getContext(), arg_types, result_types));
677 *f = new MlirFunction(std::move(context_), std::move(module_), func_);
678 return ::tensorflow::OkStatus();
679 }
680
681 extern "C" {
MlirTracingFactory(const char * fn_name,TF_Status * s)682 TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
683 return new MlirFunctionContext(fn_name);
684 }
685 }
686
687 } // namespace
688 } // namespace TF
689 } // namespace mlir
690