# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/testdata/generate_checkpoint.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Standalone utility to generate some test saved models."""
16
import sys

from absl import app

from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import dtypes
from tensorflow.python.module import module
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
25
26
class TableModule(module.Module):
  """Module holding one pre-populated int64 -> int64 `DenseHashTable`.

  The table is tracked as `lookup_table` and contains the entries 1 -> 1 and
  2 -> 4. Lookups of absent keys yield `default_value` (-1).
  """

  def __init__(self):
    # tf.Module subclasses should run the base constructor; without it the
    # module's `name` / `name_scope` attributes are never set.
    super().__init__()
    default_value = -1
    # DenseHashTable reserves two sentinel keys for internal bookkeeping
    # (empty buckets and deleted buckets). They must differ from each other
    # and must never be inserted or looked up; the data keys inserted below
    # (1 and 2) avoid both.
    empty_key = 0
    deleted_key = -1
    self.lookup_table = lookup_ops.DenseHashTable(
        dtypes.int64,
        dtypes.int64,
        default_value=default_value,
        empty_key=empty_key,
        deleted_key=deleted_key,
        name="t1",
        # Must be a power of two; presumably kept small so the generated
        # test checkpoint stays small -- TODO confirm intended size.
        initial_num_buckets=32)
    self.lookup_table.insert(1, 1)
    self.lookup_table.insert(2, 4)
44
45
class VariableModule(module.Module):
  """Module holding two float variables.

  `v` is initialized to [1., 2., 3.] and `w` to [4., 5.].
  """

  def __init__(self):
    # Run the tf.Module base constructor so `name` / `name_scope` are set.
    super().__init__()
    self.v = variables.Variable([1., 2., 3.])
    self.w = variables.Variable([4., 5.])
51
# Maps the ModuleName command-line argument to the tf.Module subclass that is
# instantiated and checkpointed. Each key is the class's own name, so the two
# cannot drift apart.
MODULE_CTORS = {
    ctor.__name__: ctor for ctor in (TableModule, VariableModule)
}
56
57
def main(args):
  """Builds the module named on the command line and writes its checkpoint.

  Args:
    args: argv as passed by `absl.app.run`; expected to be
      `[program, export_path, module_name]`.

  Returns:
    None after writing the checkpoint to `export_path`, 1 if the argument
    count is wrong, or 2 if `module_name` is not a key of `MODULE_CTORS`.
    `absl.app.run` uses this return value as the process exit status.
  """
  # Diagnostics go to stderr, and the allowed names are listed plainly rather
  # than as a `dict_keys([...])` repr.
  allowed_names = ", ".join(MODULE_CTORS)
  if len(args) != 3:
    print("Expected: {export_path} {ModuleName}", file=sys.stderr)
    print("Allowed ModuleNames:", allowed_names, file=sys.stderr)
    return 1

  _, export_path, module_name = args
  module_ctor = MODULE_CTORS.get(module_name)
  if module_ctor is None:
    print("Expected ModuleName to be one of:", allowed_names, file=sys.stderr)
    return 2

  tf_module = module_ctor()
  # `tf_module` is passed as the checkpoint's root object.
  ckpt = checkpoint.Checkpoint(tf_module)
  ckpt.write(export_path)
73
74
if __name__ == "__main__":
  # Switch the process to TF2 behavior before absl dispatches to main(),
  # i.e. before main() constructs any TensorFlow objects.
  v2_compat.enable_v2_behavior()
  app.run(main=main)
78