# mypy: allow-untyped-defs
import hashlib
import json
from typing import Dict, Tuple

import coremltools as ct  # type: ignore[import]
from coremltools.converters.mil.input_types import TensorType  # type: ignore[import]
from coremltools.converters.mil.mil import types  # type: ignore[import]
from coremltools.models.neural_network import quantization_utils  # type: ignore[import]

import torch


CT_METADATA_VERSION = "com.github.apple.coremltools.version"
CT_METADATA_SOURCE = "com.github.apple.coremltools.source"


class ScalarType:
    Float = 0
    Double = 1
    Int = 2
    Long = 3
    Undefined = 4


# Supported Tensor types in coremltools:
# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28
torch_to_mil_types = {
    ScalarType.Float: types.fp32,
    ScalarType.Double: types.fp64,
    ScalarType.Int: types.int32,
    ScalarType.Long: types.int64,
}


class CoreMLComputeUnit:
    CPU = "cpuOnly"
    CPUAndGPU = "cpuAndGPU"
    ALL = "all"


class CoreMLQuantizationMode:
    LINEAR = "linear"
    LINEAR_SYMMETRIC = "linear_symmetric"
    NONE = "none"


def TensorSpec(shape, dtype=ScalarType.Float):
    # A tensor spec is simply a (shape, dtype) pair.
    return (shape, dtype)


def CompileSpec(
    inputs,
    outputs,
    backend=CoreMLComputeUnit.CPU,
    allow_low_precision=True,
    quantization_mode=CoreMLQuantizationMode.NONE,
    mlmodel_export_path=None,
):
    # Bundle the per-method lowering options into the tuple consumed by preprocess().
    return (
        inputs,
        outputs,
        backend,
        allow_low_precision,
        quantization_mode,
        mlmodel_export_path,
    )


def _check_enumerated_shape(shape):
    # True when `shape` is a collection of candidate shapes (each itself a
    # list/tuple) rather than a single static shape.
    for s in shape:
        if not isinstance(s, (list, tuple)):
            return False
    return True


def _convert_to_mil_type(shape, dtype, name: str):
    mil_shape = shape
    if _check_enumerated_shape(shape):
        mil_shape = ct.EnumeratedShapes(shape)
    ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
    ml_type.name = name
    return ml_type


def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
    """Convert a scripted PyTorch module to a CoreML model and package the
    serialized model, its hash, and the compile spec for the CoreML delegate."""
    spec = compile_spec["forward"]
    (
        input_specs,
        output_specs,
        backend,
        allow_low_precision,
        quantization_mode,
        mlmodel_export_path,
    ) = spec
    mil_inputs = []
    inputs = []
    for index, input in enumerate(input_specs):
        shape, dtype = input
        name = "input_" + str(index)
        inputs.append([name, str(dtype), str(shape)])
        ml_type = _convert_to_mil_type(shape, dtype, name)
        mil_inputs.append(ml_type)
    # Rewrap the raw C++ module so coremltools can convert it.
    model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
    mlmodel = ct.convert(model, inputs=mil_inputs)

    # Optionally quantize the model weights to 8 bits.
    if quantization_mode != CoreMLQuantizationMode.NONE:
        quant_model_spec = quantization_utils.quantize_weights(
            mlmodel, nbits=8, quantization_mode=quantization_mode
        )
        mlmodel = ct.models.MLModel(quant_model_spec)

    spec = mlmodel.get_spec()
    assert len(spec.description.output) == len(output_specs)  # type: ignore[attr-defined]
    # Record the CoreML output names so the runtime can bind the outputs.
    outputs = []
    for index, output in enumerate(output_specs):
        shape, dtype = output
        name = spec.description.output[index].name  # type: ignore[attr-defined]
        outputs.append([name, str(dtype), str(shape)])
    mlmodel = ct.models.model.MLModel(spec)
    print(mlmodel)

    if mlmodel_export_path is not None:
        print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}")
        mlmodel.save(mlmodel_export_path)

    config = {
        "spec_ver": str(spec.specificationVersion),  # type: ignore[attr-defined]
        "backend": backend,
        "allow_low_precision": str(allow_low_precision),
    }
    metadata = {
        "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],
        "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE],
    }
    coreml_compile_spec = {
        "inputs": inputs,
        "outputs": outputs,
        "config": config,
        "metadata": metadata,
    }
    # Serialize the CoreML protobuf; the delegate deserializes it at load time.
    mlmodel = spec.SerializeToString()  # type: ignore[attr-defined]

    return {
        "model": mlmodel,
        "hash": str(hashlib.sha256(mlmodel).hexdigest()),
        "extra": json.dumps(coreml_compile_spec),
    }