# Summary
This example demonstrates how to run [Llama models](https://www.llama.com/) on mobile via ExecuTorch. We use XNNPACK to accelerate performance and 4-bit groupwise quantization to fit the model on a phone.

Here are the supported models:

- Llama 3.2 1B and 3B
- Llama 3.2 Quantized 1B and 3B
- Llama 3.1 8B
- Llama 3 8B
- [Llama 2 7B](../llama2/README.md)

Pretrained models are not included in this repo. You can download them [here](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).

This page contains the basic recipe for running Llama. See the [Llama utils page](./UTILS.md) for more advanced use cases, such as fine-tuning and running smaller models for educational purposes.

# What is Llama?
Llama is a collection of large language models that use publicly available data for training. These models are based on the transformer architecture, which allows them to process input sequences of arbitrary length and generate output sequences of variable length. One of the key features of Llama models is their ability to generate coherent and contextually relevant text. This is achieved through the use of attention mechanisms, which allow the model to focus on different parts of the input sequence as it generates output. Additionally, Llama models are pre-trained on a large corpus of text with a next-token (causal) language modeling objective, which helps them learn to predict the next word in a sequence.

Llama models have been shown to perform well on a variety of natural language processing tasks, including language translation, question answering, and text summarization. They are also capable of generating human-like text, making them a useful tool for creative writing and other applications where natural language generation is important.

Overall, Llama models are powerful and versatile language models that can be used for a wide range of natural language processing tasks. The models' ability to generate coherent and contextually relevant text makes them particularly useful for applications such as chatbots, virtual assistants, and language translation.

Please note that the models are subject to the [Llama 2 Acceptable Use Policy](https://github.com/facebookresearch/llama/blob/main/USE_POLICY.md), [Llama 3 Acceptable Use Policy](https://github.com/meta-llama/llama3/blob/main/USE_POLICY.md), and [Responsible Use Guide](https://ai.meta.com/static-resource/responsible-use-guide/).


# Results

## Llama 3.2 1B/3B and quantized 1B/3B models

For the Llama 3.2 1B/3B models, we have enabled the original BF16 format and 4-bit quantization, using SpinQuant and QAT+LoRA, for enhanced performance.

The quantized models were optimized primarily for the Arm CPU architecture by leveraging XNNPACK and the KleidiAI library. Work is underway to enable quantization specifically on mobile accelerators for Llama 1B/3B.

### Enablement

We have successfully verified performance on the following devices: iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, Samsung Galaxy S22, and OnePlus 12 (featuring 16GB RAM).

Note that the Llama 3.2 3B unquantized BF16 model was only tested on the OnePlus 12, which has sufficient memory (16GB RAM) to support its size requirements.

### Quantization

The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied.
To achieve a balance between accuracy, performance, and memory, we utilized 4-bit quantization using the [SpinQuant](https://github.com/facebookresearch/SpinQuant/tree/main) and QAT+LoRA methods.

Our quantization scheme involves three parts, applicable to both methods:

- We quantize all linear layers in all transformer blocks to a 4-bit groupwise scheme (with a group size of 32) for weights and 8-bit per-token dynamic quantization for activations.
- The classification layer is quantized to 8-bit per-channel for weights and 8-bit per-token dynamic quantization for activations.
- We employ 8-bit per-channel quantization for the embedding layer.

We use [torchao](https://github.com/pytorch/ao) library APIs to define these schemes.

#### SpinQuant

The SpinQuant method takes the original weights and produces optimized quantized weights with minimal outliers, resulting in higher accuracy. This can be achieved without any fine-tuning of the weights and only requires 100 iterations on a single A100 node.

SpinQuant can generate quantized weights that are [compatible with ExecuTorch](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch); specifically, it can be integrated with the existing optimized XNNPACK kernels (e.g., groupwise 4-bit weight and 8-bit dynamic activation). This allows developers to benefit from the higher accuracy of SpinQuant while also taking advantage of the strong performance of ExecuTorch acceleration.

#### Quantization-Aware Training and LoRA (QAT+LoRA)

Quantization-Aware Training (QAT) is employed to simulate the effects of quantization during the training of Llama 3.2 models, enabling optimization of their performance in low-precision environments. To initialize QAT, BF16 Llama 3.2 model checkpoints obtained after supervised fine-tuning (SFT) are used, and an additional full round of SFT training with QAT is performed. The backbone of the QAT model is then frozen and another round of SFT is performed with low-rank adaptation (LoRA) adapters applied to all layers within the transformer block. Meanwhile, the LoRA adapters' weights and activations are maintained in BF16.

### Accuracy

Please see the [Llama 3.2 model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) for accuracy evaluations.

### Performance

Llama 3.2 1B and 3B performance was measured on an Android OnePlus 12 device. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-4-run-benchmark-on-android-phone) with a prompt length of 64, and was taken with the KleidiAI library enabled. KleidiAI is not enabled by default yet; use `-DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON` to enable it in the build, as shown below.
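For example, assuming the Android build from Step 4, the option can be added to the ExecuTorch configure step there; the sketch below is identical to the Step 4 command apart from the added `-DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON` line.

```
# ExecuTorch configure step from Step 4, with KleidiAI enabled.
cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_ENABLE_LOGGING=1 \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android .

cmake --build cmake-out-android -j16 --target install --config Release
```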
|Model | Decode (tokens/s) | Time-to-first-token (sec) | Prefill (tokens/s) | Model size (PTE file size in MiB) | Memory size (RSS in MiB) |
|-------|------------------:|--------------------------:|-------------------:|----------------------------------:|-------------------------:|
|1B BF16 (baseline) | 19.2 | 1.0 | 60.3 | 2,358 | 3,185 |
|1B SpinQuant | 50.2 (2.6x) | 0.3 (-76.9%) | 260.5 (4.3x) | 1,083 (-54.1%) | 1,921 (-39.7%) |
|1B QAT+LoRA | 45.8 (2.4x) | 0.3 (-76.0%) | 252.0 (4.2x) | 1,127 (-52.2%) | 2,255 (-29.2%) |
|3B BF16 (baseline) | 7.6 | 3.0 | 21.2 | 6,129 | 7,419 |
|3B SpinQuant | 19.7 (2.6x) | 0.7 (-76.4%) | 89.7 (4.2x) | 2,435 (-60.3%) | 3,726 (-49.8%) |
|3B QAT+LoRA | 18.5 (2.4x) | 0.7 (-76.1%) | 88.8 (4.2x) | 2,529 (-58.7%) | 4,060 (-45.3%) |


<table>
  <tr>
    <td>
      <img src="./Android3_2_1B_bf16.gif" width="300">
      <br>
      <em> Llama 3.2 1B, unquantized, BF16, on an Android phone. </em>
    </td>
    <td>
      <img src="./Android3_2_3B_SpinQuant.gif" width="300">
      <br>
      <em>
        Llama 3.2 3B, 4-bit quantized (SpinQuant), on an Android phone.
      </em>
    </td>
  </tr>
</table>

## Llama 3/3.1 8B
Since the Llama 3 8B model needs at least 4-bit quantization to fit even on some high-end phones, the results presented here correspond to the 4-bit groupwise post-training quantized (PTQ) model.

### Enablement

For Llama 3 8B and Llama 3.1 8B, we have so far verified performance on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, and OnePlus 12 (with 16GB RAM) by quantizing to 4-bit.

### Quantization

We employed 4-bit groupwise post-training quantization (PTQ) with per-token dynamic activation quantization for all the linear layers of the model. Dynamic quantization refers to quantizing activations dynamically, such that the quantization parameters for activations are calculated from the min/max range at runtime. Here we quantized activations to 8 bits (signed integer), while weights are statically quantized; in our case weights were per-channel groupwise quantized with 4-bit signed integers. Due to Llama 3's vocabulary size, we had to quantize the embedding lookup table as well. For these results, the embedding lookup table was groupwise quantized to 4 bits with a group size of 32.

We use [torchao](https://github.com/pytorch/ao) library APIs to define these schemes.

### Accuracy

We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different group sizes, with max_seq_length 2048 and limit 1000.

|Model | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256) |
|--------|----------------:|----------------------:|----------------------:|
|Llama 3 8B | 7.9 | 9.4 | 9.7 |

Please note that LM Eval reports perplexity normalized by word count instead of token count. You may see different perplexity numbers for WikiText from other sources if they implement it differently. More details can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/issues/2301).

### Performance

Llama 3 8B performance was measured on Samsung Galaxy S22, Samsung Galaxy S24, and OnePlus 12 devices. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-4-run-benchmark-on-android-phone); a representative invocation is sketched below.
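The sketch below assumes the runner binary, exported `.pte` file, and tokenizer have already been pushed to the device as described in Step 4; the model file name and prompt are placeholders.

```
# Assumes llama_main, the exported .pte, and tokenizer.model are already in
# /data/local/tmp/llama (see Step 4); file name and prompt are placeholders.
adb shell "cd /data/local/tmp/llama && ./llama_main --model_path llama3_kv_sdpa_xnn_qe_4_32.pte --tokenizer_path tokenizer.model --prompt \"Once upon a time\" --seq_len 120 --warmup=1"
```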
|Device | Groupwise 4-bit (128) | Groupwise 4-bit (256) |
|--------|----------------------:|----------------------:|
|Galaxy S22 | 7.85 tokens/second | 8.4 tokens/second |
|Galaxy S24 | 10.91 tokens/second | 11.21 tokens/second |
|OnePlus 12 | 10.85 tokens/second | 11.02 tokens/second |

<p align="center">
  <br>
  <img src="./llama_via_xnnpack.gif" width=300>
  <br>
  <em>
    Llama 3.1 8B, 4-bit quantized, on an Android phone
  </em>
</p>

To try it on non-CPU backends, including CoreML, MPS, Qualcomm HTP, or MediaTek, [please visit this section](non_cpu_backends.md).

# Instructions

## Tested on

- macOS (M1/M2), Linux.
- For Llama 3 8B, your device may require at least 32GB RAM. If this is a constraint for you, please try the [smaller stories model](./UTILS.md).

## Step 1: Setup
> :warning: **Double-check your Python environment**: make sure `conda activate <VENV>` has been run before any of the bash and Python scripts.

1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. For installation, run `./install_requirements.sh --pybind xnnpack`.
2. Run `examples/models/llama/install_requirements.sh` to install a few dependencies.


## Step 2: Prepare model

### Option A: Download and export Llama 3.2 1B/3B model

1. Download `consolidated.00.pth`, `params.json`, and `tokenizer.model` from the [Llama website](https://www.llama.com/llama-downloads/) or [Hugging Face](https://huggingface.co/meta-llama/Llama-3.2-1B). For chat use cases, download the instruct models.

2. Export the model and generate the `.pte` file.

- To use the **original BF16** version, without any quantization:
```
# No quantization
# Set these paths to point to the downloaded files
LLAMA_CHECKPOINT=path/to/checkpoint.pth
LLAMA_PARAMS=path/to/params.json

python -m examples.models.llama.export_llama \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -kv \
   --use_sdpa_with_kv_cache \
   -X \
   -d bf16 \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
   --output_name="llama3_2.pte"
```

- To use **SpinQuant**, there are two options:
  - Download the model directly from the [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to a `.pte` file directly.
  - Follow the SpinQuant repository's [instructions](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch) to produce an ExecuTorch-compatible checkpoint, then export that checkpoint.

```
# SpinQuant
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
LLAMA_PARAMS=path/to/spinquant/params.json

python -m examples.models.llama.export_llama \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
   -X \
   --xnnpack-extended-ops \
   --preq_mode 8da4w_output_8da8w \
   --preq_group_size 32 \
   --max_seq_length 2048 \
   --output_name "llama3_2.pte" \
   -kv \
   -d fp32 \
   --preq_embedding_quantize 8,0 \
   --use_spin_quant native \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```

- To use **QAT+LoRA**, download the model directly from the [Llama website](https://www.llama.com/llama-downloads).
The model weights are prequantized and can be exported to a `.pte` file directly by:

```
# QAT+LoRA
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
LLAMA_PARAMS=path/to/qlora/params.json

python -m examples.models.llama.export_llama \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -qat \
   -lora 16 \
   --preq_mode 8da4w_output_8da8w \
   --preq_group_size 32 \
   --preq_embedding_quantize 8,0 \
   --use_sdpa_with_kv_cache \
   -kv \
   -X \
   --xnnpack-extended-ops \
   -d fp32 \
   --max_seq_length 2048 \
   --output_name "llama3_2.pte" \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```

### Option B: Download and export Llama 3 8B instruct model

You can export and run the original Llama 3 8B instruct model.

1. Llama 3 pretrained parameters can be downloaded from [Meta's official Llama 3 repository](https://github.com/meta-llama/llama3/).

2. Export the model and generate the `.pte` file:
    ```
    python -m examples.models.llama.export_llama \
      --checkpoint <consolidated.00.pth> \
      -p <params.json> \
      -kv \
      --use_sdpa_with_kv_cache \
      -X \
      -qmode 8da4w \
      --group_size 128 \
      -d fp32 \
      --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
      --embedding-quantize 4,32 \
      --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
    ```
    Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.

    If you're interested in deploying on non-CPU backends, [please refer to the non-CPU-backend section](non_cpu_backends.md).

## Step 3: Run on your computer to validate

1. Build ExecuTorch with optimized CPU performance as follows. Build options are available [here](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L59).
    ```
    cmake -DPYTHON_EXECUTABLE=python \
        -DCMAKE_INSTALL_PREFIX=cmake-out \
        -DEXECUTORCH_ENABLE_LOGGING=1 \
        -DCMAKE_BUILD_TYPE=Release \
        -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
        -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
        -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
        -DEXECUTORCH_BUILD_XNNPACK=ON \
        -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
        -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
        -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
        -Bcmake-out .

    cmake --build cmake-out -j16 --target install --config Release
    ```
    Note for Mac users: there is a known linking issue with Xcode 15.1. Refer to the Common Issues and Mitigations section below for solutions.

2. Build the llama runner.
    ```
    cmake -DPYTHON_EXECUTABLE=python \
        -DCMAKE_INSTALL_PREFIX=cmake-out \
        -DCMAKE_BUILD_TYPE=Release \
        -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
        -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
        -DEXECUTORCH_BUILD_XNNPACK=ON \
        -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
        -Bcmake-out/examples/models/llama \
        examples/models/llama

    cmake --build cmake-out/examples/models/llama -j16 --config Release
    ```

3. Run the model. Run options are available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama/main.cpp#L18-L40).
    ```
    cmake-out/examples/models/llama/llama_main --model_path=<model pte file> --tokenizer_path=<tokenizer.model> --prompt=<prompt>
    ```

To build for the CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON`.

## Step 4: Run benchmark on Android phone

**1. Build llama runner binary for Android**

*Pre-requisite*: Android NDK (tested with r27b), which can be downloaded from [here](https://developer.android.com/ndk/downloads). Note that the Mac binary can be unpacked, and you can locate the NDK folder inside it.

**1.1 Set Android NDK**
```
export ANDROID_NDK=<path-to-android-ndk>
```
**1.2 Build ExecuTorch and associated libraries for Android**
```
cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_ENABLE_LOGGING=1 \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android .

cmake --build cmake-out-android -j16 --target install --config Release
```

**1.3 Build llama runner for Android**
```
cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android/examples/models/llama \
    examples/models/llama

cmake --build cmake-out-android/examples/models/llama -j16 --config Release
```

**2. Run on Android via adb shell**

*Pre-requisite*: Make sure you enable USB debugging via developer options on your phone.

**2.1 Connect your Android phone**

**2.2 Upload model, tokenizer, and llama runner binary to the phone**
```
adb shell mkdir -p /data/local/tmp/llama
adb push <model.pte> /data/local/tmp/llama/
adb push <tokenizer.model> /data/local/tmp/llama/
adb push cmake-out-android/examples/models/llama/llama_main /data/local/tmp/llama/
```

**2.3 Run model**
```
adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.model> --prompt \"What is the capital of France?\" --seq_len 120 --warmup=1"
```
## Step 5: Build Mobile apps

### iOS

Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-demo-ios.html) for full instructions on building the iOS LLAMA Demo App. Rename the `tokenizer.model` file to `tokenizer.bin`, because the demo app looks for the tokenizer file with a .bin extension.

### Android
Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-demo-android.html) for full instructions on building the Android LLAMA Demo App.


## Utility tools for Llama enablement

### Evaluate model accuracy

> Forewarning: Model evaluation without a GPU may take a long time, especially on larger models.
We use [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate model accuracy.

For base models, use the following example command to calculate perplexity on WikiText.
```
python -m examples.models.llama.eval_llama \
    -c <checkpoint.pth> \
    -p <params.json> \
    -t <tokenizer.model/bin> \
    -kv \
    -d <checkpoint dtype> \
    --max_seq_len <max sequence length> \
    --limit <number of samples>
```

For instruct models, use the following example command to calculate the MMLU score.
```
python -m examples.models.llama.eval_llama \
    -c <checkpoint.pth> \
    -p <params.json> \
    -t <tokenizer.model/bin> \
    -kv \
    -d <checkpoint dtype> \
    --tasks mmlu \
    --num_fewshot 5 \
    --max_seq_len <max sequence length>
```

See the [Llama utils page](./UTILS.md) for more advanced use cases, such as fine-tuning, running smaller models for educational purposes, and quick iteration and verification.

# What is coming next?
## Quantization
- Enabling the FP16 model to leverage smaller group sizes for 4-bit quantization.
- Enabling GPTQ for 4-bit groupwise quantization.
- Enabling custom quantization.
- Lower-bit quantization.
## Models
- Enabling more generative AI models and architectures.
## Performance
- Performance improvements via techniques such as speculative decoding.
- Enabling Llama and other architectures via Vulkan.
- Enabling performant execution of widely used quantization schemes.

# Notes
This example tries to reuse the Python code, with minimal modifications to make it compatible with current ExecuTorch:
1. Since ExecuTorch does not support the complex Tensor data type, customized functions are used to implement rotary embedding with real numbers. Please see [GitHub issue: Support complex data type in ExecuTorch](https://github.com/pytorch/executorch/issues/886).
2. No CUDA. ExecuTorch is focused on edge use cases where CUDA is not available on most edge devices.
3. No dependencies on fairscale. ColumnParallelLinear, ParallelEmbedding, and training are neither needed nor supported in ExecuTorch.
# Common Issues and Mitigations:
- To clean your build:
```
git clean -xfd
pip uninstall executorch
./install_requirements.sh --pybind xnnpack

rm -rf cmake-out
```
- If you encounter `pthread`-related issues at link time, add `pthread` to `target_link_libraries` in `CMakeLists.txt`.
- On Mac, if there is a linking error during the build steps above with an error message like:
```
0 0x100823648 __assert_rtn + 72
1 0x10074bc5c ld::Fixup::applyFixup(ld::Atom const*, ld::LayoutLinkedImage const&, unsigned char*) const + 8268
2 0x1007de7d8 ___ZN2ld16LayoutExecutable27writeContentWithoutLinkEditENSt3__14spanIhLm18446744073709551615EEEy_block_invoke + 332
3 0x188cca428 _dispatch_client_callout2 + 20
4 0x188cde850 _dispatch_apply_invoke3 + 336
5 0x188cca3e8 _dispatch_client_callout + 20
6 0x188ccbc68 _dispatch_once_callout + 32
7 0x188cdeeec _dispatch_apply_invoke_and_wait + 372
8 0x188cdde9c _dispatch_apply_with_attr_f + 1212
9 0x188cde08c dispatch_apply + 96
10 0x1007de9e4 void mapReduce<ld::Atom const*, mach_o::Error>(std::__1::span<ld::Atom const*, 18446744073709551615ul>, unsigned long, void (unsigned long, mach_o::Error&, std::__1::span<ld::Atom const*, 18446744073709551615ul>) block_pointer, void (std::__1::span<mach_o::Error, 18446744073709551615ul>) block_pointer) + 336
11 0x1007de594 ld::LayoutExecutable::writeContentWithoutLinkEdit(std::__1::span<unsigned char, 18446744073709551615ul>, unsigned long long) + 1180
12 0x1007e4020 ld::LayoutExecutable::writeToFile(char const*) + 15248
13 0x1007962e8 main + 9424
ld: Assertion failed: (extras.otherInstrOffset != 0 && "Kind::arm64_adrp_ldr missing extra info"), function applyFixup, file Fixup.cpp, line 793.
clang: error: linker command failed with exit code 1 (use -v to see invocation)
```
It's a known issue for Xcode version 15.1.
Mitigation: update to the most recent Xcode version, then clean and rebuild.
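A rough sketch of that mitigation, assuming Xcode is installed and selected via `xcode-select`:
```
# Confirm the active Xcode version; anything newer than 15.1 avoids the assertion above.
xcodebuild -version

# Clean the previous build, then re-run the cmake configure/build commands from Step 3.
rm -rf cmake-out
./install_requirements.sh --pybind xnnpack
```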