# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Standalone utility to generate some test checkpoints.

Usage: <binary> {export_path} {ModuleName}

Builds the module registered under `ModuleName` in `MODULE_CTORS` and writes
a checkpoint of it to `export_path`.
"""

from absl import app

from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.compat import v2_compat
from tensorflow.python.framework import dtypes
from tensorflow.python.module import module
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables


class TableModule(module.Module):
  """Module holding one int64 -> int64 `DenseHashTable` with two entries."""

  def __init__(self):
    super().__init__()
    default_value = -1
    empty_key = 0
    deleted_key = -1
    self.lookup_table = lookup_ops.DenseHashTable(
        dtypes.int64,
        dtypes.int64,
        default_value=default_value,
        empty_key=empty_key,
        deleted_key=deleted_key,
        name="t1",
        initial_num_buckets=32)
    # Populate the table so the written checkpoint contains real entries.
    self.lookup_table.insert(1, 1)
    self.lookup_table.insert(2, 4)


class VariableModule(module.Module):
  """Module holding two float variables, `v` (3 elements) and `w` (2)."""

  def __init__(self):
    super().__init__()
    self.v = variables.Variable([1., 2., 3.])
    self.w = variables.Variable([4., 5.])


# Maps the ModuleName command-line argument to the class that builds it.
MODULE_CTORS = {
    "TableModule": TableModule,
    "VariableModule": VariableModule,
}


def main(args):
  """Builds the requested module and writes a checkpoint of it.

  Args:
    args: Command-line arguments; expected to be exactly
      `[program_name, export_path, module_name]`.

  Returns:
    1 if the argument count is wrong, 2 if `module_name` is not a key of
    `MODULE_CTORS`, and None after the checkpoint has been written.
  """
  if len(args) != 3:
    print("Expected: {export_path} {ModuleName}")
    print("Allowed ModuleNames:", ", ".join(MODULE_CTORS))
    return 1

  _, export_path, module_name = args
  module_ctor = MODULE_CTORS.get(module_name)
  if module_ctor is None:
    print("Expected ModuleName to be one of:", ", ".join(MODULE_CTORS))
    return 2

  tf_module = module_ctor()
  ckpt = checkpoint.Checkpoint(tf_module)
  ckpt.write(export_path)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  app.run(main)