SelfAttention_MLX
bug :(

========================================================================================= test session starts =========================================================================================
platform darwin -- Python 3.11.5, pytest-7.4.0, pluggy-1.0.0
rootdir: /Users/belladeanhardt/ml-mdm
configfile: pyproject.toml
plugins: cov-5.0.0, anyio-3.5.0, mock-3.14.0
collected 2 items

tests/test_mlx_unet.py .F                                                                                                                                                                        [100%]

============================================================================================== FAILURES ===============================================================================================
___________________________________________________________________________________ test_pytorch_mlx_self_attention ___________________________________________________________________________________
    def test_pytorch_mlx_self_attention():
        """
        Test for feature parity between PyTorch and MLX implementations of SelfAttention.
        We'll test both the basic self-attention and conditional attention scenarios.
        """
        # Define test parameters
        channels = 64  # Number of channels
        batch_size = 2  # Batch size
        spatial_size = 8  # Spatial dimensions (H=W=8)
        cond_dim = 32  # Conditional dimension
        num_heads = 8  # Number of attention heads

        # Create model instances
        pytorch_attn = SelfAttention(
            channels=channels,
            num_heads=num_heads,
            cond_dim=cond_dim,
            use_attention_ffn=True,
        )
        mlx_attn = SelfAttention_MLX(  # Assuming this is your MLX class name
            channels=channels,
            num_heads=num_heads,
            cond_dim=cond_dim,
            use_attention_ffn=True,
        )

        # Set models to evaluation mode
        pytorch_attn.eval()
        mlx_attn.eval()

        # Create test inputs
        # Main input: [B, C, H, W]
        pytorch_input = torch.randn(batch_size, channels, spatial_size, spatial_size)
        # Conditional input: [B, seq_len, cond_dim]
        cond_seq_len = 4
        pytorch_cond = torch.randn(batch_size, cond_seq_len, cond_dim)
        # Conditional mask: [B, seq_len]
        pytorch_cond_mask = torch.ones(batch_size, cond_seq_len)

        # Test PyTorch version
        pytorch_output = pytorch_attn(
            pytorch_input, cond=pytorch_cond, cond_mask=pytorch_cond_mask
        )

        # Convert inputs to MLX format
        mlx_input = mx.array(pytorch_input.numpy())
        mlx_cond = mx.array(pytorch_cond.numpy())
        mlx_cond_mask = mx.array(pytorch_cond_mask.numpy())

        # Test MLX version
>       mlx_output = mlx_attn.forward(mlx_input, cond=mlx_cond, cond_mask=mlx_cond_mask)
tests/test_mlx_unet.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ml_mdm/models/unet_mlx.py:126: in forward
    h = self.proj_out(h)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Conv2d(64, 64, kernel_size=(1,), stride=(1, 1), padding=(0, 0), dilation=1, groups=1, bias=True)
x = array([[[[0.275203, 0.260433, 0.149972, ..., 0.170037, 0.294273, 0.0233037], [0.238134, 0.305211, 0.304386, ....0327954, 0.0905073], [0.0745733, 0.266315, 0.0139441, ..., -0.124577, -0.181556, 0.0526186]]]], dtype=float32)

    def __call__(self, x):
>       y = mx.conv2d(
            x, self.weight, self.stride, self.padding, self.dilation, self.groups
        )
E       ValueError: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,64,8,8) and weight: (64,1,1,64)

../anaconda3/lib/python3.11/site-packages/mlx/nn/layers/convolution.py:157: ValueError
---------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------
input tensor for group norm: (2, 8, 8, 64)
output tensor for groupnorm / input tensor for self attention: (2, 3, 64, 64)
output tensor for self attention: (2, 3, 64, 64)
---------- coverage: platform darwin, python 3.11.5-final-0 ----------
Name                                       Stmts   Miss  Cover
---------------------------------------------------------------
ml_mdm/__init__.py                             1      0   100%
ml_mdm/clis/__init__.py                        0      0   100%
ml_mdm/clis/download_tar_from_index.py       198    198     0%
ml_mdm/clis/generate_batch.py                139    139     0%
ml_mdm/clis/generate_sample.py               225    225     0%
ml_mdm/clis/run_torchmetrics.py              120    120     0%
ml_mdm/clis/scrape_cc12m.py                   62     62     0%
ml_mdm/clis/train_parallel.py                157    157     0%
ml_mdm/config.py                             127     85    33%
ml_mdm/diffusion.py                          192    135    30%
ml_mdm/distributed.py                         39     28    28%
ml_mdm/generate_html.py                       18     18     0%
ml_mdm/helpers.py                              8      8     0%
ml_mdm/language_models/__init__.py             0      0   100%
ml_mdm/language_models/factory.py             68     68     0%
ml_mdm/language_models/self_attention.py       5      5     0%
ml_mdm/language_models/tokenizer.py          118    105    11%
ml_mdm/language_models/transformer.py          5      5     0%
ml_mdm/lr_scaler.py                           18     18     0%
ml_mdm/models/__init__.py                      1      0   100%
ml_mdm/models/model_ema.py                    41     31    24%
ml_mdm/models/nested_unet.py                 115     71    38%
ml_mdm/models/unet.py                        507    368    27%
ml_mdm/models/unet_mlx.py                     65      7    89%
ml_mdm/reader.py                             125     89    29%
ml_mdm/s3_helpers.py                          56     43    23%
ml_mdm/samplers.py                           354    276    22%
ml_mdm/trainer.py                             52     52     0%
ml_mdm/utils/__init__.py                       0      0   100%
ml_mdm/utils/fix_old_checkpoints.py           10      7    30%
ml_mdm/utils/simple_logger.py                 85     85     0%
---------------------------------------------------------------
TOTAL                                       2911   2405    17%
======================================================================================= short test summary info =======================================================================================
FAILED tests/test_mlx_unet.py::test_pytorch_mlx_self_attention - ValueError: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,64,8,8) and weight: (64,1,1,64)
gonna check it, but the error is probably in how the inputs get converted to mx.array: the traceback shows the conv receiving an NCHW input (2, 64, 8, 8), while MLX stores conv weights channels-last as (64, 1, 1, 64) and expects NHWC input. sketched a possible fix below
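something like this is what I'd try. it's a minimal sketch, not tested against the real module: `proj_out` here is a hypothetical stand-in for `self.proj_out` in `unet_mlx.py`, and the shapes are the ones from the error message

```python
import mlx.core as mx
import mlx.nn as nn

# Stand-in for self.proj_out; MLX Conv2d weights are laid out
# (out_channels, kH, kW, in_channels), i.e. (64, 1, 1, 64) here.
proj_out = nn.Conv2d(64, 64, kernel_size=1)

h = mx.random.normal((2, 64, 8, 8))  # NCHW, the failing input shape
h = h.transpose(0, 2, 3, 1)          # -> NHWC (2, 8, 8, 64), what mx.conv2d expects
y = proj_out(h)                      # input channels now match the weight layout
y = y.transpose(0, 3, 1, 2)          # back to NCHW to compare against the PyTorch output
print(y.shape)                       # (2, 64, 8, 8)
```

and since the captured stdout shows group norm already running on (2, 8, 8, 64), the layout is probably getting flipped back to NCHW somewhere between the attention block and proj_out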
@bdeanhardt i think i've got this
@gabrielfnayres awesome!
should i open another PR? @luke-carlson