coremltools
coremltools copied to clipboard
Is there support to run ANE accelerated loops/while_loop?
Hi, I've been experimenting with while_loop(s), but I haven't had any success making them run accelerated on either the ANE or the GPU. Is it even possible for them to run with acceleration?
Here's some example code; the loop uses just a simple counter as its exit condition:
import numpy as np
import coremltools as ct
import coremltools.converters.mil as mil
from coremltools.converters.mil import Builder as mb
# Sizes used to build the two tensor shapes below.
bsize, seqlen, dim = 4, 128, 512
# Shape of the `q` input: (bsize, seqlen, dim).
qshape = (bsize, seqlen, dim)
# Shape of the square `k` input: (dim, dim).
kshape = (dim,) * 2
@mb.program(
    input_specs=[
        mb.TensorSpec(shape=qshape, dtype=mil.input_types.types.fp16),
        mb.TensorSpec(shape=kshape, dtype=mil.input_types.types.fp16),
        mb.TensorSpec(shape=(1,), dtype=mil.input_types.types.int32),
    ],
    opset_version=mil.builder.AvailableTarget.iOS17,
)
def loop(q, k, l):
    """Apply ``sigmoid(state @ k)`` repeatedly, starting from ``state = q``.

    A one-element int32 counter starts at 0 and is incremented by 1 on
    every iteration; the loop keeps running while ``counter < l``.

    Args:
        q: fp16 tensor of shape ``qshape``; the initial loop state.
        k: fp16 tensor of shape ``kshape``; right-hand matmul operand.
        l: int32 tensor of shape ``(1,)``; upper bound for the counter.

    Returns:
        The output of ``mb.while_loop``: the final ``(counter, state)``
        loop variables.
    """

    def keep_going(step, _state):
        # Exit condition: only the counter matters, the state is ignored.
        return mb.less(x=step, y=l)

    def advance(step, state):
        projected = mb.matmul(x=state, y=k, transpose_y=False)
        next_state = mb.sigmoid(x=projected)
        next_step = mb.add(x=step, y=np.ones([1], dtype=np.int32))
        return next_step, next_state

    # Counter initialised to a single int32 zero.
    counter = mb.fill(shape=np.array([1]), value=np.array(0., dtype=np.int32))
    return mb.while_loop(
        _cond=keep_going,
        _body=advance,
        loop_vars=(counter, q),
    )
# Convert the MIL program `loop` to a Core ML model: fp16 compute precision,
# iOS17 deployment target, restricted to the CPU and the Neural Engine.
mlmodel = ct.convert(
    loop,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS17,
    inputs=[
        ct.TensorType(name='q', shape=ct.Shape(shape=qshape)),
        ct.TensorType(name='k', shape=ct.Shape(shape=kshape)),
        # `l` is the one-element loop bound. It was previously declared as
        # (seqlen,), which disagreed with both the program's
        # TensorSpec(shape=(1,)) for `l` and the (1,)-shaped array passed
        # to predict().
        ct.TensorType(name='l', shape=ct.Shape(shape=(1,))),
    ],
)
# Random fp16 sample inputs (normal distribution, scale 0.2) for `q` and `k`.
q = np.random.normal(scale=0.2, size=qshape).astype(np.float16)
k = np.random.normal(scale=0.2, size=kshape).astype(np.float16)
# Loop bound: a single int32 value of 16.
l = np.array([16]).astype(np.int32)
# Run one prediction with the three named inputs.
mlmodel.predict(
    {
        'q': q,
        'k': k,
        'l': l,
    }
)
I don't think Apple wants to directly document or comment on this behaviour because it may change in the future etc.
For now, from what I can tell, the ANE cannot deal with any control flow whatsoever (including if/else) – it can take a static set of instructions and compute them quickly, but that's it.