xref: /aosp_15_r20/external/tensorflow/tensorflow/python/pywrap_mlir.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python module for MLIR functions exported by pybind11."""
16
17# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
18from tensorflow.python import pywrap_tensorflow
19from tensorflow.python.eager import context
20from tensorflow.python._pywrap_mlir import *
21
22
def import_graphdef(graphdef,
                    pass_pipeline,
                    show_debug_info,
                    input_names=None,
                    input_data_types=None,
                    input_data_shapes=None,
                    output_names=None):
  """Imports a GraphDef through the `ImportGraphDef` binding.

  All string-like arguments are UTF-8 encoded before being handed to the
  pybind11 binding.

  Args:
    graphdef: The GraphDef to import; it is passed as its `str()` form.
    pass_pipeline: Pass pipeline string forwarded to the binding.
    show_debug_info: Flag forwarded unchanged to the binding.
    input_names: Optional iterable of input names, joined with ','. When it
      is None, only `graphdef`, `pass_pipeline` and `show_debug_info` are
      forwarded, and the remaining arguments (including `output_names`) are
      ignored.
    input_data_types: Iterable of strings joined with ','. Must be provided
      (not None) whenever `input_names` is provided.
    input_data_shapes: Iterable of strings joined with ':'. Must be provided
      (not None) whenever `input_names` is provided.
    output_names: Optional iterable of output names, joined with ','.
      None (the default) means an empty list of outputs.

  Returns:
    Whatever `ImportGraphDef` returns (presumably the imported MLIR module as
    text -- NOTE(review): confirm against the binding).
  """
  graphdef_bytes = str(graphdef).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  if input_names is None:
    return ImportGraphDef(graphdef_bytes, pipeline_bytes, show_debug_info)

  # Use None instead of a shared mutable `[]` default; an empty tuple joins to
  # the same empty string the old default produced.
  if output_names is None:
    output_names = ()
  return ImportGraphDef(
      graphdef_bytes, pipeline_bytes, show_debug_info,
      ','.join(input_names).encode('utf-8'),
      ','.join(input_data_types).encode('utf-8'),
      ':'.join(input_data_shapes).encode('utf-8'),
      ','.join(output_names).encode('utf-8'))
40
41
def import_function(concrete_function, pass_pipeline, show_debug_info):
  """Imports a concrete function's FunctionDef through `ImportFunction`.

  Args:
    concrete_function: Object exposing a `function_def` attribute; that
      attribute's `str()` form is what gets imported.
    pass_pipeline: Pass pipeline string forwarded to the binding.
    show_debug_info: Flag forwarded unchanged to the binding.

  Returns:
    Whatever the `ImportFunction` binding returns.
  """
  eager_ctx = context.context()
  # Initialize before reading the handle (presumably so `_handle` is
  # populated -- NOTE(review): confirm).
  eager_ctx.ensure_initialized()
  ctx_handle = eager_ctx._handle
  function_def_bytes = str(concrete_function.function_def).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ImportFunction(ctx_handle, function_def_bytes, pipeline_bytes,
                        show_debug_info)
48
49
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
                                             show_debug_info):
  """Converts a SavedModel via the `ExperimentalConvertSavedModelToMlir` binding.

  Args:
    saved_model_path: Path to the SavedModel; passed as its `str()` form.
    exported_names: Passed as its `str()` form. NOTE(review): presumably a
      comma-separated string; a Python list would be sent as its repr.
    show_debug_info: Flag forwarded unchanged to the binding.

  Returns:
    Whatever the binding returns.
  """
  path_bytes = str(saved_model_path).encode('utf-8')
  names_bytes = str(exported_names).encode('utf-8')
  return ExperimentalConvertSavedModelToMlir(path_bytes, names_bytes,
                                             show_debug_info)
55
56
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path,
                                                     exported_names, tags,
                                                     upgrade_legacy,
                                                     show_debug_info):
  """Converts a V1 SavedModel via `ExperimentalConvertSavedModelV1ToMlirLite`.

  Args:
    saved_model_path: Path to the SavedModel; passed as its `str()` form.
    exported_names: Passed as its `str()` form (presumably comma-separated
      -- NOTE(review): confirm).
    tags: Passed as its `str()` form.
    upgrade_legacy: Flag forwarded unchanged to the binding.
    show_debug_info: Flag forwarded unchanged to the binding.

  Returns:
    Whatever the binding returns.
  """

  def _as_utf8(value):
    # Stringify first so non-str inputs (e.g. path-like objects) are accepted.
    return str(value).encode('utf-8')

  path_bytes = _as_utf8(saved_model_path)
  names_bytes = _as_utf8(exported_names)
  tags_bytes = _as_utf8(tags)
  return ExperimentalConvertSavedModelV1ToMlirLite(path_bytes, names_bytes,
                                                   tags_bytes, upgrade_legacy,
                                                   show_debug_info)
65
66
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
                                                exported_names, tags,
                                                lift_variables, upgrade_legacy,
                                                show_debug_info):
  """Converts a V1 SavedModel via `ExperimentalConvertSavedModelV1ToMlir`.

  Args:
    saved_model_path: Path to the SavedModel; passed as its `str()` form.
    exported_names: Passed as its `str()` form (presumably comma-separated
      -- NOTE(review): confirm).
    tags: Passed as its `str()` form.
    lift_variables: Flag forwarded unchanged to the binding.
    upgrade_legacy: Flag forwarded unchanged to the binding.
    show_debug_info: Flag forwarded unchanged to the binding.

  Returns:
    Whatever the binding returns.
  """

  def _as_utf8(value):
    # Stringify first so non-str inputs (e.g. path-like objects) are accepted.
    return str(value).encode('utf-8')

  path_bytes = _as_utf8(saved_model_path)
  names_bytes = _as_utf8(exported_names)
  tags_bytes = _as_utf8(tags)
  return ExperimentalConvertSavedModelV1ToMlir(path_bytes, names_bytes,
                                               tags_bytes, lift_variables,
                                               upgrade_legacy, show_debug_info)
76
77
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
  """Runs a pass pipeline over MLIR text via `ExperimentalRunPassPipeline`.

  Args:
    mlir_txt: MLIR text as a `str` (encoded to UTF-8 directly, not via
      `str()`).
    pass_pipeline: Pass pipeline string forwarded to the binding.
    show_debug_info: Flag forwarded unchanged to the binding.

  Returns:
    Whatever the binding returns.
  """
  module_bytes = mlir_txt.encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ExperimentalRunPassPipeline(module_bytes, pipeline_bytes,
                                     show_debug_info)
81