xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h"
17 
18 #include "absl/strings/str_split.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
22 #include "mlir/Pass/PassManager.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
26 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
30 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
31 #include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
32 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
33 #include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
34 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
35 #include "tfrt/tensor/dense_host_tensor_view.h"  // from @tf_runtime
36 
37 namespace tensorflow {
38 namespace {
39 
ProcessIndexPath(mlir::ArrayAttr index_path)40 llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) {
41   if (index_path.size() == 1 && index_path[0].isa<mlir::StringAttr>()) {
42     // TODO(chky): Support cases where index_path is not a single string.
43     return index_path[0].cast<mlir::StringAttr>().getValue();
44   }
45   return "";
46 }
47 
48 StatusOr<std::pair<tensorflow::DataType, tensorflow::PartialTensorShape>>
ProcessTensorSpec(mlir::TensorType type)49 ProcessTensorSpec(mlir::TensorType type) {
50   tensorflow::DataType dtype;
51   TF_RETURN_IF_ERROR(
52       ConvertScalarTypeToDataType(type.getElementType(), &dtype));
53 
54   if (!type.hasRank())
55     return std::make_pair(dtype, tensorflow::PartialTensorShape());
56 
57   auto shape = type.getShape();
58   llvm::SmallVector<int64_t, 4> dims;
59   dims.assign(shape.begin(), shape.end());
60   return std::make_pair(dtype, tensorflow::PartialTensorShape(dims));
61 }
62 
63 }  // namespace
64 
MapFunctionSignaturesFromTFSavedModelMLIR(mlir::ModuleOp module,llvm::function_ref<void (const TFRTSavedModelSignatureInfo &)> map_fn)65 Status MapFunctionSignaturesFromTFSavedModelMLIR(
66     mlir::ModuleOp module,
67     llvm::function_ref<void(const TFRTSavedModelSignatureInfo&)> map_fn) {
68   // Create bound inputs for each functions.
69   mlir::SymbolTable symbol_table(module);
70   tensorflow::Status status = OkStatus();
71   module.walk([&symbol_table, map_fn, &status](mlir::func::FuncOp func) {
72     // Use the exported name as the function name, and skip non-exported
73     // functions.
74     auto func_names = mlir::tf_saved_model::GetExportedNames(func);
75     if (func_names.empty()) return mlir::WalkResult::advance();
76 
77     auto func_type = func.getFunctionType();
78 
79     // Here we walk through each arguments and find out the input/output names,
80     // and input devices, variables used by this function.
81     llvm::SmallVector<llvm::StringRef, 4> input_names;
82     llvm::SmallVector<
83         std::pair<tensorflow::DataType, tensorflow::PartialTensorShape>, 4>
84         input_specs;
85     llvm::SmallVector<llvm::StringRef, 4> input_devices;
86     llvm::SmallVector<mlir::Operation*, 4> bound_inputs;
87     for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
88       if (auto input_index_path = func.getArgAttrOfType<mlir::ArrayAttr>(
89               i, "tf_saved_model.index_path")) {
90         input_names.push_back(ProcessIndexPath(input_index_path));
91         auto statusor_spec =
92             ProcessTensorSpec(func_type.getInput(i).cast<mlir::TensorType>());
93         if (!statusor_spec.ok()) {
94           status = std::move(statusor_spec).status();
95           return mlir::WalkResult::interrupt();
96         }
97         input_specs.push_back(std::move(statusor_spec).ValueOrDie());
98         if (auto input_device =
99                 func.getArgAttrOfType<mlir::StringAttr>(i, "tf.device")) {
100           input_devices.push_back(input_device.getValue());
101         } else {
102           input_devices.push_back("");
103         }
104       }
105       if (auto* bound_input =
106               mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
107         bound_inputs.push_back(bound_input);
108       }
109     }
110 
111     llvm::SmallVector<llvm::StringRef, 4> output_names;
112     llvm::SmallVector<
113         std::pair<tensorflow::DataType, tensorflow::PartialTensorShape>, 4>
114         output_specs;
115     for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) {
116       if (auto output_index_path = func.getResultAttrOfType<mlir::ArrayAttr>(
117               i, "tf_saved_model.index_path")) {
118         output_names.push_back(ProcessIndexPath(output_index_path));
119         auto statusor_spec =
120             ProcessTensorSpec(func_type.getResult(i).cast<mlir::TensorType>());
121         if (!statusor_spec.ok()) {
122           status = std::move(statusor_spec).status();
123           return mlir::WalkResult::interrupt();
124         }
125         output_specs.push_back(std::move(statusor_spec).ValueOrDie());
126       }
127     }
128 
129     for (auto func_name : func_names) {
130       TFRTSavedModelSignatureInfo sig_info;
131       sig_info.func_name = func_name;
132       sig_info.input_names = input_names;
133       sig_info.input_specs = input_specs;
134       sig_info.input_devices = input_devices;
135       sig_info.output_names = output_names;
136       sig_info.output_specs = output_specs;
137       sig_info.bound_inputs = bound_inputs;
138       map_fn(sig_info);
139     }
140 
141     return mlir::WalkResult::advance();
142   });
143 
144   return status;
145 }
146 
147 }  // namespace tensorflow
148