jax-metal: Failed assertion...expected element type f32 but received si32
Description
I set up a venv for my project using jax-metal, but hit the following assertion error when I ran my otherwise functioning code:
/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1650: failed assertion `Incompatible element type for parameter at index 18, mlir module expected element type f32 but received si32'
The error comes without a stack trace. Disabling jit for the entire script avoids the error, but because the error then disappears, that doesn't help me locate it either. On a hunch, I found that the issue was somehow related to a no-op Flax module I have that looks like this:
```python
import flax.linen as nn
from flax.linen.initializers import zeros

class Foo(nn.Module):
    """No-op module with stub parameters"""

    @nn.compact
    def __call__(self, x):
        # need some sort of unused param for pytree reasons
        self.param('null', zeros, 0)  # changing 0 -> (0,) does not fix anything
                                      # changing 0 -> (1,) does fix the assertion
        return x
```
As noted in the comment, changing the shape of the unused param does fix the error. However, the assertion is not raised when the module's apply function is called. Stepping through line by line with a debugger, I found that it is raised much later, during the teardown of my jit'd training step function.
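For anyone repeating the jit-disabling experiment mentioned above, JAX can turn off compilation either for a block of code with the `jax.disable_jit()` context manager or for a whole script with the `jax_disable_jit` config flag. A minimal sketch, using a hypothetical stand-in function `toy_step` rather than the real training step:

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    """Hypothetical stand-in for a jit'd training step."""
    # Under jit, x is an abstract tracer; with jit disabled it is a concrete array.
    print("inside toy_step, x is a", type(x).__name__)
    return jnp.mean(x ** 2)


x = jnp.ones((10, 10))

# Compiled path: the print runs once, at trace time, on a tracer.
compiled = toy_step(x)

# Eager path: the same Python body runs op by op, so a debugger sees real values.
with jax.disable_jit():
    eager = toy_step(x)

# To disable jit for a whole script instead, set the flag before any jitted call:
# jax.config.update("jax_disable_jit", True)
```

Running eagerly is slower, but it lets a debugger stop inside the function body, which is why it is a common first step when an error only appears under jit.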
Here is a schematic example of my code organization. This code does not reproduce the error, but is intended to show where the assertion gets raised in relation to the module's apply function:
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.linen.initializers import zeros
from flax.training.train_state import TrainState
from jax import random


class Foo(nn.Module):
    """No-op module with stub parameters"""

    @nn.compact
    def __call__(self, x):
        self.param('null', zeros, 0)  # changing 0 -> (1,) does fix the assertion
        return x


@jax.jit
def fwd(p, s, x):
    """Model forward function"""
    y = s.apply_fn(p, x)  # apply the problematic module
    # ... apply multiple other modules and functions to y here (elided)
    l = jnp.mean((y - x) ** 2)
    return l, {'val': y, 'loss': l}


@jax.jit
def step(ss, xx):
    """Training step helper"""
    (l, m), grads = jax.value_and_grad(fwd, has_aux=True)(ss.params, ss, xx)
    new_state = ss.apply_gradients(grads=grads)
    return new_state, l, m  # assertion is raised here in original code


# initialize the model
rng = random.PRNGKey(0)
foo = Foo()
params = jax.jit(foo.init)(rng, jnp.ones((1,)))
state = TrainState.create(
    apply_fn=jax.jit(foo.apply),
    params=params,
    tx=optax.sgd(0.01),
)

# run for some steps
for _ in range(10):
    state, loss, metrics = step(state, jnp.ones((10, 10)))
```
Anyway, I'm not sure whether this is a bug or intended behavior in jax-metal. I have a workaround, but I would like to understand why that parameter can't be a zero-sized array in my project code when it works fine in this example code. Or maybe the problem isn't related to that module at all? But then, why does changing that module fix the error? It seems similar to #16435, but that issue didn't involve a zero-sized array either, so I'm not sure.
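Since the workaround hinges on a zero-sized parameter, one way to audit a larger model for the same pattern is to walk the parameter pytree and report every leaf with zero elements, together with its dtype (the assertion complains about an f32/si32 mismatch). A minimal sketch, assuming the params are an ordinary nested dict like the one Flax's `init` returns; the helper name `find_zero_sized_leaves` and the example params are hypothetical:

```python
import jax
import jax.numpy as jnp


def find_zero_sized_leaves(tree):
    """Return (path, shape, dtype) for every array leaf with zero elements."""
    found = []
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        if hasattr(leaf, "size") and leaf.size == 0:
            found.append((jax.tree_util.keystr(path), leaf.shape, leaf.dtype))
    return found


# Hypothetical params mirroring the Foo module: one zero-sized stub, one normal weight.
params = {
    "params": {
        "Foo_0": {"null": jnp.zeros(0)},
        "Dense_0": {"kernel": jnp.zeros((3, 4))},
    }
}

for path, shape, dtype in find_zero_sized_leaves(params):
    print(path, shape, dtype)
```

Any path this reports is a candidate for the same (1,)-shape workaround.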
System info (python version, jaxlib version, accelerator, etc.)
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
jax: 0.4.20
jaxlib: 0.4.20
jax-metal: 0.0.5
numpy: 1.26.4
python: 3.10.13 (main, Aug 24 2023, 12:59:26) [Clang 15.0.0 (clang-1500.0.40.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
The issue is not reproducible. Do you still see the same problem with the latest macOS 14.4 and jax-metal 0.0.6?
I just updated to macOS 14.4 and jax-metal 0.0.6, and the issue still occurs if I pass 0 or (0,) as the shape of the stub param in my module, but not if I pass (1,).
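For context on why 0 and (0,) behave the same: the zeros initializer forwards the shape to `jnp.zeros`, which accepts either an int or a tuple, and both produce the same one-dimensional array with no elements, whereas (1,) produces a single element. A small check, calling `jax.nn.initializers.zeros` directly (assumed here to be the same initializer Flax re-exports as `flax.linen.initializers.zeros`):

```python
import jax

key = jax.random.PRNGKey(0)
# Initializer call signature: (key, shape, dtype), as used by Flax's self.param.
init = jax.nn.initializers.zeros

a = init(key, 0)     # int shape
b = init(key, (0,))  # tuple shape
c = init(key, (1,))  # the working workaround

for name, arr in [("0", a), ("(0,)", b), ("(1,)", c)]:
    print(f"shape arg {name:>5}: shape={arr.shape}, size={arr.size}, dtype={arr.dtype}")
```

So the two failing variants are the same zero-sized parameter, which is consistent with the error tracking array size rather than how the shape was spelled.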
As I say, the above code is a schematic that does not reproduce the error, and I'm unfortunately not in a position at the moment to start trimming my full code base down to a minimal reproducible example. It'll probably be at least a month or so before I'll have that sort of time.
I understand if you want to close the issue since it's not reproducible. I just wanted to make sure this occurrence was at least documented in case someone else has a similar issue.