# xref: /aosp_15_r20/external/pytorch/torch/backends/_coreml/preprocess.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import hashlib
3import json
4from typing import Dict, Tuple
5
6import coremltools as ct  # type: ignore[import]
7from coremltools.converters.mil.input_types import TensorType  # type: ignore[import]
8from coremltools.converters.mil.mil import types  # type: ignore[import]
9from coremltools.models.neural_network import quantization_utils  # type: ignore[import]
10
11import torch
12
13
# Keys under which coremltools records converter metadata in
# MLModel.user_defined_metadata; preprocess() reads these back to embed the
# coremltools and torch versions in the compiled model's "extra" payload.
CT_METADATA_VERSION = "com.github.apple.coremltools.version"
CT_METADATA_SOURCE = "com.github.apple.coremltools.source"
16
17
class ScalarType:
    """Integer codes for the tensor dtypes this backend understands.

    The numeric values are part of the serialized compile spec (they are
    stringified into the "inputs"/"outputs" metadata), so they must stay
    stable: Float=0, Double=1, Int=2, Long=3, Undefined=4.
    """

    Float, Double, Int, Long, Undefined = range(5)
24
25
# Supported Tensor types in coremltools:
# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28
# Maps our ScalarType codes onto the corresponding MIL element types.
torch_to_mil_types = dict(
    zip(
        (ScalarType.Float, ScalarType.Double, ScalarType.Int, ScalarType.Long),
        (types.fp32, types.fp64, types.int32, types.int64),
    )
)
34
35
class CoreMLComputeUnit:
    """String flags selecting which hardware CoreML may use for inference.

    The chosen value travels through CompileSpec into the "config.backend"
    field of the compiled model's metadata.
    """

    CPU = "cpuOnly"  # restrict execution to the CPU
    CPUAndGPU = "cpuAndGPU"  # allow CPU and GPU
    ALL = "all"  # let CoreML pick any available compute unit
40
41
class CoreMLQuantizationMode:
    """Weight-quantization modes forwarded verbatim to coremltools'
    quantization_utils.quantize_weights (see preprocess())."""

    LINEAR = "linear"
    LINEAR_SYMMETRIC = "linear_symmetric"
    NONE = "none"  # sentinel: skip the quantization pass entirely
46
47
def TensorSpec(shape, dtype=ScalarType.Float):
    """Bundle a tensor *shape* with a ScalarType code into a plain 2-tuple.

    preprocess() later unpacks each spec as ``shape, dtype = spec``.
    """
    return shape, dtype
50
51
def CompileSpec(
    inputs,
    outputs,
    backend=CoreMLComputeUnit.CPU,
    allow_low_precision=True,
    quantization_mode=CoreMLQuantizationMode.NONE,
    mlmodel_export_path=None,
):
    """Pack the per-method compile options into the 6-tuple preprocess() expects.

    The positional order must match the unpacking in preprocess():
    (inputs, outputs, backend, allow_low_precision, quantization_mode,
    mlmodel_export_path).
    """
    spec = (
        inputs,
        outputs,
        backend,
        allow_low_precision,
        quantization_mode,
        mlmodel_export_path,
    )
    return spec
68
69
70def _check_enumerated_shape(shape):
71    for s in shape:
72        if not isinstance(s, (list, tuple)):
73            return False
74    return True
75
76
def _convert_to_mil_type(shape, dtype, name: str):
    """Build a named coremltools TensorType for one input spec.

    Enumerated shapes (a list of concrete shape tuples) are wrapped in
    ct.EnumeratedShapes; a plain shape is passed through unchanged. *dtype*
    is a ScalarType code mapped via torch_to_mil_types.
    """
    if _check_enumerated_shape(shape):
        mil_shape = ct.EnumeratedShapes(shape)
    else:
        mil_shape = shape
    tensor_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
    tensor_type.name = name
    return tensor_type
84
85
def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
    """Compile a scripted PyTorch module into a serialized CoreML model.

    Args:
        script_module: raw TorchScript object backing the model.
        compile_spec: maps method name to a CompileSpec tuple; only the
            "forward" entry is consumed here.

    Returns:
        dict with:
            "model": the CoreML model protobuf serialized to bytes,
            "hash":  sha256 hex digest of those bytes,
            "extra": JSON-encoded input/output/config/metadata description.
    """
    (
        input_specs,
        output_specs,
        backend,
        allow_low_precision,
        quantization_mode,
        mlmodel_export_path,
    ) = compile_spec["forward"]

    # Describe each input for the MIL converter; names are positional
    # ("input_0", "input_1", ...) and recorded in the "extra" payload.
    mil_inputs = []
    inputs = []
    # NOTE: loop variables renamed from `input`/`output` to avoid shadowing builtins.
    for index, input_spec in enumerate(input_specs):
        shape, dtype = input_spec
        name = f"input_{index}"
        inputs.append([name, str(dtype), str(shape)])
        mil_inputs.append(_convert_to_mil_type(shape, dtype, name))

    # Re-wrap the bare ScriptObject so coremltools receives a ScriptModule.
    model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
    mlmodel = ct.convert(model, inputs=mil_inputs)

    if quantization_mode != CoreMLQuantizationMode.NONE:
        # 8-bit weight quantization; mode is "linear" or "linear_symmetric".
        quant_model_spec = quantization_utils.quantize_weights(
            mlmodel, nbits=8, quantization_mode=quantization_mode
        )
        mlmodel = ct.models.MLModel(quant_model_spec)

    model_spec = mlmodel.get_spec()
    # CoreML chooses the output names; pair them with the user-declared
    # dtype/shape so the runtime can bind results back to TorchScript outputs.
    assert len(model_spec.description.output) == len(output_specs)  # type: ignore[attr-defined]
    outputs = []
    for index, output_spec in enumerate(output_specs):
        shape, dtype = output_spec
        name = model_spec.description.output[index].name  # type: ignore[attr-defined]
        outputs.append([name, str(dtype), str(shape)])
    mlmodel = ct.models.model.MLModel(model_spec)
    # NOTE(review): prints the model summary to stdout on every compile —
    # looks intentional (matches upstream), left as-is.
    print(mlmodel)

    if mlmodel_export_path is not None:
        print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}")
        mlmodel.save(mlmodel_export_path)

    config = {
        "spec_ver": str(model_spec.specificationVersion),  # type: ignore[attr-defined]
        "backend": backend,
        "allow_low_precision": str(allow_low_precision),
    }
    metadata = {
        "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],
        "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE],
    }
    coreml_compile_spec = {
        "inputs": inputs,
        "outputs": outputs,
        "config": config,
        "metadata": metadata,
    }
    # Renamed from the original's `mlmodel` reuse: this is the serialized
    # protobuf, not an MLModel object.
    model_bytes = model_spec.SerializeToString()  # type: ignore[attr-defined]

    return {
        "model": model_bytes,
        "hash": str(hashlib.sha256(model_bytes).hexdigest()),
        "extra": json.dumps(coreml_compile_spec),
    }
149