"""Generate the ``torch.package`` archives used for backwards-compatibility tests.

Run this file as a script to (re)create the ``.pt`` packages in the
``package_bc`` directory next to this file.
"""
from pathlib import Path

import torch
from torch.fx import symbolic_trace
from torch.package import PackageExporter
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE

# Output directory (kept as a ``str``) for the generated ``.pt`` packages.
packaging_directory = f"{Path(__file__).parent}/package_bc"

# Turn off torch.package's private ``_gate_torchscript_serialization`` flag at
# import time. Presumably this is what permits pickling the TorchScript module
# below into a package. NOTE(review): private API; confirm it still exists and
# has this meaning in the torch version in use.
torch.package.package_exporter._gate_torchscript_serialization = False


def _export_package(name: str, obj: object) -> None:
    """Pickle ``obj`` into ``<packaging_directory>/test_<name>.pt``.

    The object is stored under package ``name`` with resource ``<name>.pkl``.
    Every dependency module (pattern ``"**"``) is interned, i.e. its source is
    copied into the package instead of being externed.
    """
    path = f"{packaging_directory}/test_{name}.pt"
    with PackageExporter(path) as exporter:
        exporter.intern("**")
        exporter.save_pickle(name, f"{name}.pkl", obj)


def generate_bc_packages():
    """Create the packages used for testing backwards compatibility.

    Three packages are written into ``packaging_directory``. Each one is built
    from ``TestNnModule``:

    * ``test_nn_module.pt``: the eager module.
    * ``test_torchscript_module.pt``: the module compiled by ``torch.jit.script``.
    * ``test_fx_module.pt``: the module traced by ``torch.fx.symbolic_trace``.

    Does nothing when running in fbcode outside of Sandcastle.
    """
    # Same condition as ``not (not IS_FBCODE or IS_SANDCASTLE)``.
    if IS_FBCODE and not IS_SANDCASTLE:
        return

    # Imported here, as before, so it is only needed when packages are built.
    from package_a.test_nn_module import TestNnModule

    # Build every module before writing anything. A scripting or tracing
    # failure then leaves no partially regenerated set of packages behind.
    test_nn_module = TestNnModule()
    test_torchscript_module = torch.jit.script(TestNnModule())
    test_fx_module: torch.fx.GraphModule = symbolic_trace(TestNnModule())

    # PackageExporter cannot create missing parent directories itself.
    Path(packaging_directory).mkdir(parents=True, exist_ok=True)

    _export_package("nn_module", test_nn_module)
    _export_package("torchscript_module", test_torchscript_module)
    _export_package("fx_module", test_fx_module)


if __name__ == "__main__":
    generate_bc_packages()