Failed to build the model execution plan using a model architecture file
🐞 Describing the bug
Hello. I'm trying to convert a PyTorch model to a stateful Core ML model.
I wrote this code referring to the WWDC 2024 session on the Mistral-7B model. The .mlpackage file does appear after the run, but a "Failed to build the model execution plan using a model architecture file" error appears when the Core ML class is initialized.
Stack Trace
/opt/homebrew/lib/python3.11/site-packages/transformers/modeling_utils.py:4779: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
warnings.warn(
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
/Users/kimbuseong/Downloads/zenz-CoreML/convert-to-CoreML-Stateful.py:70: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if past_key.size(-2) > 0:
Torch var valueCache is added again.
Torch var keyCache is added again.
Converting PyTorch Frontend ==> MIL Ops: 100%|██████████| 1600/1600 [00:00<00:00, 2510.79 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 25.71 passes/s]
Running MIL default pipeline:  65%|██████▌   | 57/88 [00:03<00:02, 13.10 passes/s]
/opt/homebrew/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py:894: RuntimeWarning: overflow encountered in cast
return input_var.val.astype(dtype=string_to_nptype(dtype_val))
Running MIL default pipeline: 100%|██████████| 88/88 [00:06<00:00, 12.66 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 62.18 passes/s]
/opt/homebrew/lib/python3.11/site-packages/coremltools/models/model.py:489: RuntimeWarning: You will not be able to run predict() on this Core ML model. Underlying exception message was: {
NSLocalizedDescription = "Failed to build the model execution plan using a model architecture file '/private/var/folders/pz/rmstwmls5ls_0hrn5_jj01kh0000gn/T/tmppa7zpned.mlmodelc/model.mil' with error code: -14.";
}
_warnings.warn(
Model successfully converted and saved as: zenz_v1_cached.mlpackage
To Reproduce
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention, GPT2_ATTENTION_CLASSES
from transformers import AutoTokenizer
import coremltools as ct
from typing import Optional, Tuple
import numpy as np
from transformers.cache_utils import Cache
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...],
        device="cpu",
        dtype=torch.float32,
    ) -> None:
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(slice_indices) != 2:
            raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, _: int = 0) -> int:
        return self.past_seen_tokens

    def to_past_key_values(self):
        """Convert the internal cache to a format expected by GPT2."""
        return [(self.k_cache[layer], self.v_cache[layer]) for layer in range(self.k_cache.size(0))]
class SliceUpdateGPT2Attention(GPT2Attention):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config=config, layer_idx=layer_idx)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: bool = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Keep the existing logic
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            if past_key.size(-2) > 0:
                key = torch.cat([past_key, key], dim=-2)
                value = torch.cat([past_value, value], dim=-2)
            if attention_mask is not None:
                attention_mask = attention_mask[:, :, :, -key.size(-2):]
        # Modified to return the attention weights
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        present = (key, value) if use_cache else None
        if output_attentions:
            return attn_output, present, attn_weights
        else:
            return attn_output, present
class StatefulZenz(torch.nn.Module):
    def __init__(self, model, max_context_size: int = 256, batch_size: int = 1):
        super(StatefulZenz, self).__init__()
        GPT2_ATTENTION_CLASSES["sdpa"] = SliceUpdateGPT2Attention
        self.model = model
        config = self.model.config
        self.kv_cache_shape: Tuple[int, ...] = (
            config.num_hidden_layers,
            batch_size,
            config.n_head,
            max_context_size,
            config.hidden_size // config.num_attention_heads,
        )
        self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
        self.register_buffer("keyCache", self.kv_cache.k_cache)
        self.register_buffer("valueCache", self.kv_cache.v_cache)

    def _extend_attention_mask(self, attention_mask, past_key_values):
        if past_key_values is not None:
            past_length = past_key_values[0][0].size(-2)
            new_length = past_length + attention_mask.size(-1)
            extended_attention_mask = torch.ones(
                (attention_mask.size(0), 1, 1, new_length),
                dtype=torch.float32,
                device=attention_mask.device,
            )
            extended_attention_mask[:, :, :, -attention_mask.size(-1):] = attention_mask
            return extended_attention_mask
        return attention_mask

    @torch.no_grad()
    def forward(self, input_ids, attention_mask):
        self.kv_cache.past_seen_tokens = attention_mask.shape[-1] - input_ids.shape[-1]
        past_key_values = self.kv_cache.to_past_key_values()
        outputs = self.model(
            input_ids,
            attention_mask=self._extend_attention_mask(attention_mask=attention_mask, past_key_values=past_key_values),
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=True,  # set to return the attention weights
        )
        return outputs.logits
def convert_model(model_name: str, output_path: str):
    # Set up model and tokenizer
    GPT2_ATTENTION_CLASSES["sdpa"] = SliceUpdateGPT2Attention
    model = GPT2LMHeadModel.from_pretrained(model_name).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Prepare example input
    text = "Example sentence"
    inputs = tokenizer(text, return_tensors="pt")

    # Create stateful model
    stateful_zenz = StatefulZenz(model).eval()

    # Trace the model with example inputs
    example_inputs = (inputs['input_ids'], inputs['attention_mask'])
    traced_model = torch.jit.trace(
        stateful_zenz,
        example_inputs,
        check_trace=False,  # disable trace checking to avoid minor numerical differences
    )

    # Convert to Core ML
    mlmodel = ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(
                name="input_ids",
                shape=(1, ct.RangeDim(1, 256)),
                dtype=np.float32,
            ),
            ct.TensorType(
                name="attention_mask",
                shape=(1, ct.RangeDim(1, 256)),
                dtype=np.float32,
            ),
        ],
        outputs=[
            ct.TensorType(
                name="output",
                dtype=np.float32,
            )
        ],
        states=[
            ct.StateType(
                wrapped_type=ct.TensorType(
                    shape=stateful_zenz.kv_cache_shape,
                    dtype=np.float16,
                ),
                name="keyCache",
            ),
            ct.StateType(
                wrapped_type=ct.TensorType(
                    shape=stateful_zenz.kv_cache_shape,
                    dtype=np.float16,
                ),
                name="valueCache",
            ),
        ],
        minimum_deployment_target=ct.target.iOS18,
    )
    mlmodel.save(output_path)
    print(f"Model successfully converted and saved as: {output_path}")


# Usage
model_name = "Miwa-Keita/zenz-v1-checkpoints"
convert_model(model_name, "zenz_v1_cached.mlpackage")
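The error then shows up as soon as the saved package is loaded and run. A minimal sketch of that step (the input values here are only illustrative):

import numpy as np
import coremltools as ct

# Loading/compiling the package is where "Failed to build the model
# execution plan using a model architecture file" is raised.
coreml_model = ct.models.MLModel("zenz_v1_cached.mlpackage")

# For a stateful model, create the KV-cache state and pass it to predict().
state = coreml_model.make_state()
prediction = coreml_model.predict(
    {
        "input_ids": np.ones((1, 3), dtype=np.float32),
        "attention_mask": np.ones((1, 3), dtype=np.float32),
    },
    state,
)
print(prediction["output"].shape)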
System environment (please complete the following information):
- coremltools version: 8.0b2
- OS (e.g. MacOS version or Linux type): macOS 15.1 Beta (24B5024e)
- Any other relevant version information (e.g. PyTorch or TensorFlow version):
  - Python 3.11 (Homebrew)
  - torch 2.3.0
  - torchvision 0.18.0
  - transformers 4.41.0
@Skyline-23 that is a lot of code. Can you give us a more minimal example?
@TobyRoseman All of the code is required to run a stateful model based on GPT-2. Sorry 😢
The official documentation example says:
converted_model_kvcache = ct.convert(
    traced_model_kvcache,
    inputs=inputs,
    outputs=outputs,
    states=states,
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
)
I got the same error with compute_units=ct.ComputeUnit.ALL, but it passes with compute_units=ct.ComputeUnit.CPU_AND_GPU.
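For reference, the difference also shows up when loading the already converted package with restricted compute units (a sketch; the path is illustrative):

import coremltools as ct

# Loads and runs on CPU/GPU only; with ct.ComputeUnit.ALL the runtime may
# place ops on the Neural Engine, where the plan-building error appears.
model = ct.models.MLModel(
    "zenz_v1_cached.mlpackage",
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
)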
@lithium0003 It's not working....
With compute_units=ct.ComputeUnit.CPU_AND_GPU and the attention_mask ignored (attention_mask = None just before self._attn()), it seems to pass, but I don't know why it passes.
I'm having the same error with Apple's checkpoint of DepthAnything. It worked a month ago.
@lithium0003 It works after adding attention_mask = None before attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
But it produces another error:
loc("/Users/kimbuseong/Library/Caches/org.python.python/com.apple.e5rt.e5bundlecache/24C5089c/482419A8FF15595378FF575F9BDC33548B8A4933527ED3D8F1364CA6FEF48A51/70D546C0C146A466AD586A6DE692334F73CB4C7D816C6FA63C4AF624CBCB818D.bundle/H13S.bundle/main/main_mps_graph/main_mps_graph.mpsgraphpackage/model_0.mpsgraph":0:0): error: attempting to parse a byte at the end of the bytecode
I think it's an MPS error, but I don't know how to resolve it.
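For reference, the workaround is just dropping the mask right before the attention call in SliceUpdateGPT2Attention.forward (a sketch of only that change):

# Inside SliceUpdateGPT2Attention.forward, right before the attention call:
attention_mask = None  # workaround: ignoring the mask avoids the -14 plan error
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)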
I found a similar error in swift-chat: https://github.com/huggingface/swift-chat/issues/24
I am also encountering this. Notably, if you try to run the mlpackage (from Swift) using .cpuAndNeuralEngine, that triggers it without fail. Also, attempting to run the model through coremltools.optimize.coreml.experimental.linear_quantize_activations during the quantization phase (if you choose to quantize it) will also trigger it.
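For what it's worth, the quantization call that triggers it looks roughly like this (a sketch based on the coremltools 8 docs; the mode choice and calibration sample_data are assumptions):

import coremltools.optimize as cto

# Activation quantization compiles and runs the model for calibration,
# which is where the same plan-building failure shows up.
activation_config = cto.coreml.OptimizationConfig(
    global_config=cto.coreml.experimental.OpActivationLinearQuantizerConfig(
        mode="linear_symmetric"
    )
)
# sample_data: a list of input dictionaries matching the model's input names.
quantized_model = cto.coreml.experimental.linear_quantize_activations(
    mlmodel, activation_config, sample_data
)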
Removing the state parameters entirely from ct.convert and using a simple torch.nn.Module allows you to use the NE again, but obviously this means you do not get to leverage the new stateful features.
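Concretely, the stateless fallback is the same ct.convert call minus the states argument (a sketch; the KV cache must then be managed outside Core ML, e.g. as regular inputs/outputs):

# Same ct.convert call as above, minus `states=`; the resulting model can
# run on the Neural Engine again, but has no persistent KV cache.
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=(1, ct.RangeDim(1, 256)), dtype=np.float32),
        ct.TensorType(name="attention_mask", shape=(1, ct.RangeDim(1, 256)), dtype=np.float32),
    ],
    outputs=[ct.TensorType(name="output", dtype=np.float32)],
    minimum_deployment_target=ct.target.iOS18,
)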
Any updates? coremltools 8.3 has been released, but nothing is fixed....
@Skyline-23 - Can you give us a minimal example to reproduce this issue?
I encountered the same problem. I converted a causal LM with coremltools. There was no error during the conversion, but an error occurred during inference that could not be located:
`loc("/Users/xx/Library/Caches/python/com.apple.e5rt.e5bundlecache/24F74/949B0E364F41517D0F438CC20F342C46E47495D322383F351EE7D8E90ACB615D/54E9F2B14303BEF0A555D07775BA39A19B60E78FCAFDF5848B3ABBDF62F966E8.bundle/H16S.bundle/main/main_mps_graph/main_mps_graph.mpsgraphpackage/model_0.mpsgraph":0:0): error: attempting to parse a byte at the end of the bytecode
Step 1: token_id=67511, text='์' Step 2: token_id=40853, text='์ฅ' Step 3: token_id=40853, text='์ฅ' Step 4: token_id=40853, text='์ฅ'
`
It is OK when using TorchScript:
Processed 0 tokens so far. Step 1: token_id=151667, text='<think>'
Processed 16 tokens so far. Step 2: token_id=198, text=' '
Processed 17 tokens so far. Step 3: token_id=106287, text='嗯'
Processed 18 tokens so far. Step 4: token_id=3837, text='，'
Processed 19 tokens so far. Step 5: token_id=20002, text='用户'
Processed 20 tokens so far. Step 6: token_id=56007, text='问'
kv_cache = coreml_model.make_state()
for step in range(max_new_tokens):
    start_time = time.time()
    if step == 0:
        input_ids_temp = input_ids.numpy().astype(np.int32)
        attention_mask_4d = attention_mask[:, :real_token_count].numpy().astype(np.float16)
        attention_mask_4d = attention_mask_4d[:, None, None, :]
    else:
        input_ids_temp = next_token.numpy().reshape(1, 1).astype(np.int32)
        # Build a 4D mask [1, 1, 1, total_seq_len]
        attention_mask_4d = np.ones((1, 1, 1, total_seq_len), dtype=np.float16)
    model_input = {
        "input_ids": input_ids_temp,
        "attention_mask": attention_mask_4d,
    }
    outputs = coreml_model.predict(model_input, kv_cache)
    logits_np = outputs["logits"]
    logits = torch.from_numpy(logits_np)
    next_token_logits = logits[0, -1, :] if step == 0 else logits[0, 0, :]
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
    next_token_val = next_token.item()
I am encountering the same error as @xwhboy. Were you able to figure it out? I am using coremltools 8.3.0 on macOS 15.6.1.