Model state doesn't work with transpose
🐞Describing the bug
This is a follow-up to https://github.com/apple/coremltools/issues/2275. Sorry, I couldn't find a way to reopen the original issue. To clarify, this is not a batch-prediction scenario where a list of tensors is passed as input; the input is ONLY one tensor.
I tried to lower a PyTorch Llama model with KV cache to Core ML using the new stateful-model feature introduced in coremltools 8.0.
The export steps succeeded and I could generate an mlpackage; however, at runtime the code failed immediately while constructing the model class. The error message is: "Fatal error: 'try!' expression unexpectedly raised an error: Error Domain=com.apple.CoreML Code=0 "MIL program input, 'k_cache', not found in Core ML model inputs" UserInfo={NSLocalizedDescription=MIL program input, 'k_cache', not found in Core ML model inputs}"
I debugged a bit and found that a view + transpose combination causes this issue, but couldn't get any more insight into why. The repro code is attached below. Specifically, if I change
k = k.view(1, seqlen, 16, 128)
v = v.view(1, seqlen, 16, 128)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
to
k = k.view(1, 16, seqlen, 128)
v = v.view(1, 16, seqlen, 128)
then inference works.
Stack Trace
Only the error message above; there is no further stack trace.
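Since there is no stack trace, one way to narrow this down is to check what the compiled spec actually lists as model inputs versus state features, which is what the runtime error refers to. A minimal sketch, assuming the converted model has been saved to an illustrative path:

import coremltools as ct

# Inspect the converted model description: the runtime error complains that the
# MIL program input 'k_cache' is not among the Core ML model inputs.
spec = ct.models.MLModel("test_attention.mlpackage").get_spec()
print("inputs:", [f.name for f in spec.description.input])
# 'state' is the state-feature list in the iOS 18 spec version (field name assumed).
print("states:", [f.name for f in spec.description.state])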
To Reproduce
- Minimal code example that can reproduce the error.
import torch
import torch.nn as nn
import coremltools as ct
from coremltools.optimize.torch.quantization import (
    PostTrainingQuantizer,
    PostTrainingQuantizerConfig,
)


class TestAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.wk = nn.Linear(2048, 2048, bias=False)
        self.wv = nn.Linear(2048, 2048, bias=False)
        self.wo = nn.Linear(128, 128, bias=False)
        cache_shape = (1, 16, 128, 128)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=torch.float32, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=torch.float32, device="cpu")
        )

    def forward(self, embedding):
        bsz, seqlen, _ = embedding.shape
        k, v = self.wk(embedding), self.wv(embedding)
        k = k.view(1, seqlen, 16, 128)
        v = v.view(1, seqlen, 16, 128)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        self.k_cache[:, :, 0:seqlen] = k
        self.v_cache[:, :, 0:seqlen] = v
        return self.wo(self.k_cache)


model_t = TestAttention().eval()

config = PostTrainingQuantizerConfig.from_dict(
    {
        "module_type_configs": {
            torch.nn.Linear: {
                "weight_dtype": "int4",
                "granularity": "per_channel",
            },
        }
    }
)
quantizer = PostTrainingQuantizer(model_t, config)
quantized_model = quantizer.compress()

inputs = (torch.rand(1, 48, 16 * 128),)
traced_model = torch.jit.trace(quantized_model, inputs)

states = [
    ct.StateType(
        wrapped_type=ct.TensorType(shape=(1, 16, 128, 128)),
        name=v,
    )
    for v in ["k_cache", "v_cache"]
]
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, 48, 16 * 128))],
    outputs=[ct.TensorType(name="op")],
    states=states,
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
)
Note, as mentioned above: replacing the view + transpose pair with direct views of shape (1, 16, seqlen, 128) makes inference run successfully.
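To run the model from Xcode as shown below, the converted model also needs to be saved first; for example (the file name is illustrative, and the generated Swift class name, test_attention below, is derived from the saved file name):

# Save the converted stateful model so it can be added to the Xcode project.
mlmodel.save("test_attention.mlpackage")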
- Xcode code to run the model:
func predict_kv_cache() {
    let model = try! test_attention().model
    guard let x = try? MLMultiArray(shape: [1, 48, 2048], dataType: MLMultiArrayDataType.float16) else {
        fatalError("Unexpected runtime error. MLMultiArray")
    }
    for i in 0...98303 {
        x[i] = 0.1
    }
    let inputs = test_attentionInput(embedding: x)
    let state = model.makeState()
    try! model.prediction(from: inputs, using: state)
}
It fails immediately when executing
let model = try! test_attention().model
The converted ML program looks like this:
main[CoreML8](%embedding: (1, 48, 2048, fp16)(Tensor),
%k_cache: (1, 16, 128, 128, fp16)(State),
%v_cache: (1, 16, 128, 128, fp16)(State)) {
block145() {
%wk_weight_cast_fp16: (2048, 2048, fp16)(Tensor) = constexpr_blockwise_shift_scale(data=%wk_weight_data_0, scale=%wk_weight_scale_0_to_fp16, name="wk_weight_cast_fp16")
%wv_weight_cast_fp16: (2048, 2048, fp16)(Tensor) = constexpr_blockwise_shift_scale(data=%wv_weight_data_0, scale=%wv_weight_scale_0_to_fp16, name="wv_weight_cast_fp16")
%wo_weight_cast_fp16: (128, 128, fp16)(Tensor) = constexpr_blockwise_shift_scale(data=%wo_weight_data_0, scale=%wo_weight_scale_0_to_fp16, name="wo_weight_cast_fp16")
%linear_0_cast_fp16: (1, 48, 2048, fp16)(Tensor) = linear(x=%embedding, weight=%wk_weight_cast_fp16, bias=%linear_0_bias_0_to_fp16, name="linear_0_cast_fp16")
%linear_1_cast_fp16: (1, 48, 2048, fp16)(Tensor) = linear(x=%embedding, weight=%wv_weight_cast_fp16, bias=%linear_0_bias_0_to_fp16, name="linear_1_cast_fp16")
%k_3_cast_fp16: (1, 48, 16, 128, fp16)(Tensor) = reshape(x=%linear_0_cast_fp16, shape=[1, 48, 16, 128], name="k_3_cast_fp16")
%v_3_cast_fp16: (1, 48, 16, 128, fp16)(Tensor) = reshape(x=%linear_1_cast_fp16, shape=[1, 48, 16, 128], name="v_3_cast_fp16")
%read_state_0: (1, 16, 128, 128, fp16)(Tensor) = read_state(input=%k_cache, name="read_state_0")
%k_cast_fp16: (1, 16, 48, 128, fp16)(Tensor) = transpose(x=%k_3_cast_fp16, perm=[0, 2, 1, 3], name="transpose_1")
%k_cache_internal_tensor_assign_1_cast_fp16: (1, 16, 128, 128, fp16)(Tensor) = slice_update(x=%read_state_0, update=%k_cast_fp16, begin=[0, 0, 0, 0], end=[0, 0, 48, 0], stride=[1, 1, 1, 1], begin_mask=[False, False, False, True], end_mask=[True, True, False, True], squeeze_mask=[False, False, False, False], name="k_cache_internal_tensor_assign_1_cast_fp16")
%coreml_update_state_0: (1, 16, 128, 128, fp16)(Tensor) = coreml_update_state(state=%k_cache, value=%k_cache_internal_tensor_assign_1_cast_fp16, name="coreml_update_state_0")
%read_state_1: (1, 16, 128, 128, fp16)(Tensor) = read_state(input=%v_cache, name="read_state_1")
%v_cast_fp16: (1, 16, 48, 128, fp16)(Tensor) = transpose(x=%v_3_cast_fp16, perm=[0, 2, 1, 3], name="transpose_0")
%v_cache_internal_tensor_assign_1_cast_fp16: (1, 16, 128, 128, fp16)(Tensor) = slice_update(x=%read_state_1, update=%v_cast_fp16, begin=[0, 0, 0, 0], end=[0, 0, 48, 0], stride=[1, 1, 1, 1], begin_mask=[False, False, False, True], end_mask=[True, True, False, True], squeeze_mask=[False, False, False, False], name="v_cache_internal_tensor_assign_1_cast_fp16")
%coreml_update_state_1: (1, 16, 128, 128, fp16)(Tensor) = coreml_update_state(state=%v_cache, value=%v_cache_internal_tensor_assign_1_cast_fp16, name="coreml_update_state_1")
%op: (1, 16, 128, 128, fp16)(Tensor) = linear(x=%coreml_update_state_0, weight=%wo_weight_cast_fp16, bias=%linear_2_bias_0_to_fp16, name="linear_2_cast_fp16")
} -> (%op)
}
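For reference, a dump like the one above can be printed directly from the converted model in Python; a sketch (note that _mil_program is an internal attribute of the converted MLModel and may change between coremltools versions):

# Print the MIL program of the converted model (internal attribute, may change).
print(mlmodel._mil_program)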
System environment (please complete the following information):
- coremltools version: 8.0b1
- OS (e.g. MacOS version or Linux type): running on iPhone 15 Pro with iOS 18
- Any other relevant version information (e.g. PyTorch or TensorFlow version):
Additional context
- Add anything else about the problem here that you want to share.
Are you able to get predictions from your model in Python?
I had some issues upgrading my macOS to 15, which model inference in Python requires; I will add the results when I have them...
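For reference, a minimal sketch of what that Python check could look like once macOS 15 is available (the mlpackage path is illustrative; the input name "embedding" and output name "op" come from the conversion above):

import numpy as np
import coremltools as ct

# Load the converted stateful model and run one prediction with a fresh KV-cache state.
mlmodel = ct.models.MLModel("test_attention.mlpackage", compute_units=ct.ComputeUnit.CPU_AND_NE)
state = mlmodel.make_state()
out = mlmodel.predict({"embedding": np.random.rand(1, 48, 2048).astype(np.float32)}, state)
print(out["op"].shape)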
Confirmed. IMHO it is a H/W limitation, possibly because the same ANE block is used for both slice_update and transpose/permute? Adding a "non-optimizable op", such as adding an external tensor, helps, depending on tensor size. Maybe some kind of NOP inserted while building the model execution plan could fix it?
Toggle work_around = True or False in the code below:
import torch
import torch.nn as nn
import coremltools as ct


class TestAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.wk = nn.Linear(2048, 2048, bias=False)
        self.wv = nn.Linear(2048, 2048, bias=False)
        self.wo = nn.Linear(128, 128, bias=False)
        cache_shape = (1, 16, 128, 128)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=torch.float32, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=torch.float32, device="cpu")
        )

    def forward(self, embedding, zt):
        bsz, seqlen, _ = embedding.shape
        k, v = self.wk(embedding), self.wv(embedding)
        zt = zt[:, :, :seqlen, :]
        if True:
            k = k.view(1, seqlen, 16, 128)
            v = v.view(1, seqlen, 16, 128)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            # set to False for Error, True for work around
            work_around = True
            # add external tensor
            # if transposed tensor is big you might need bigger op.
            # possibly transpose/permute and slice_update are using the same op?
            # and it needs to complete execution prior to slice_update?
            if work_around:
                zk = zt + k
                zv = zt + v
            else:
                zk = k
                zv = v
        else:
            # no transpose
            zk = k.view(1, 16, seqlen, 128)
            zv = v.view(1, 16, seqlen, 128)
        self.k_cache[:, :, 0:seqlen] = zk
        self.v_cache[:, :, 0:seqlen] = zv
        sum = self.k_cache + self.v_cache
        zt = torch.zeros(1, 16, 128, 128, dtype=zt.dtype)
        return self.wo(sum) + zt


model_t = TestAttention().eval()

inputs = (
    torch.rand(1, 48, 16 * 128),
    torch.zeros(1, 16, 48, 128),
)
traced_model = torch.jit.trace(model_t, inputs)

states = [
    ct.StateType(
        wrapped_type=ct.TensorType(shape=(1, 16, 128, 128)),
        name=v,
    )
    for v in ["k_cache", "v_cache"]
]
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, 48, 16 * 128)), ct.TensorType(shape=(1, 16, 128, 128))],
    outputs=[ct.TensorType(name="op")],
    states=states,
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
)
mlmodel.save("states-transpose.mlpackage")

mlmodel2 = ct.models.MLModel("states-transpose.mlpackage", compute_units=ct.ComputeUnit.CPU_AND_NE)
state = mlmodel2.make_state()

# Run prediction
inputs = {
    "embedding": torch.rand(1, 48, 16 * 128).numpy(),
    "zt_1": torch.zeros(1, 16, 128, 128).numpy(),
}
predictions = mlmodel2.predict(inputs, state)

# Print output shape and values
print("\nPrediction Results:")
print(f"Output shape: {predictions['op'].shape}")
print(f"Output values (first few):\n{predictions['op'].flatten()[:5]}")
Xcode profiling screenshot showing that it runs on the ANE: