"""Export a traced MobileNetV2 lowered to the PyTorch Core ML backend."""
from pathlib import Path

from torchvision import models

import torch
from torch.backends._coreml.preprocess import CompileSpec, CoreMLComputeUnit, TensorSpec


def mobilenetv2_spec():
    """Return the Core ML compile spec for a traced MobileNetV2 module.

    The spec covers only the ``forward`` method. It declares:

    - one input tensor of shape ``[1, 3, 224, 224]``;
    - one output tensor of shape ``[1, 1000]``;
    - the Core ML CPU compute unit as the backend;
    - ``allow_low_precision=True``.

    Returns:
        A dict mapping method name to its ``CompileSpec``. This is the form
        passed to ``torch._C._jit_to_backend("coreml", ...)`` in ``main``.
    """
    return {
        "forward": CompileSpec(
            inputs=(TensorSpec(shape=[1, 3, 224, 224]),),
            outputs=(TensorSpec(shape=[1, 1000]),),
            backend=CoreMLComputeUnit.CPU,
            allow_low_precision=True,
        ),
    }


def main(output_dir="../models"):
    """Trace MobileNetV2, lower it to Core ML, and save the result.

    Loads MobileNetV2 with the ``IMAGENET1K_V1`` weights. The model is traced
    with a random ``(1, 3, 224, 224)`` example, lowered with
    ``mobilenetv2_spec()``, and the lowered ``forward`` graph is printed.

    Args:
        output_dir: Directory that receives ``model_coreml.ptl`` (lite
            interpreter format) and ``model_coreml.pt`` (TorchScript).
            The directory is created if missing. A relative path resolves
            against the current working directory, as before.
    """
    out_dir = Path(output_dir)
    # Create the destination first. Otherwise a missing directory only
    # surfaces at save time, after the download, trace and lowering work.
    out_dir.mkdir(parents=True, exist_ok=True)

    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
    model.eval()
    # The example's shape must match the input TensorSpec in mobilenetv2_spec().
    example = torch.rand(1, 3, 224, 224)
    model = torch.jit.trace(model, example)

    compile_spec = mobilenetv2_spec()
    mlmodel = torch._C._jit_to_backend("coreml", model, compile_spec)
    # Print the TorchScript graph of the lowered module's forward for inspection.
    print(mlmodel._c._get_method("forward").graph)

    mlmodel._save_for_lite_interpreter(str(out_dir / "model_coreml.ptl"))
    torch.jit.save(mlmodel, str(out_dir / "model_coreml.pt"))


if __name__ == "__main__":
    main()