
SelfAttention_MLX

Open · bdeanhardt opened this issue 10 months ago · 5 comments

bdeanhardt · Feb 20 '25 21:02

bug :(

============================= test session starts ==============================
platform darwin -- Python 3.11.5, pytest-7.4.0, pluggy-1.0.0
rootdir: /Users/belladeanhardt/ml-mdm
configfile: pyproject.toml
plugins: cov-5.0.0, anyio-3.5.0, mock-3.14.0
collected 2 items

tests/test_mlx_unet.py .F [100%]

=================================== FAILURES ===================================
_______________________ test_pytorch_mlx_self_attention _______________________

def test_pytorch_mlx_self_attention():
    """
    Test for feature parity between PyTorch and MLX implementations of SelfAttention.
    We'll test both the basic self-attention and conditional attention scenarios.
    """
    # Define test parameters
    channels = 64  # Number of channels
    batch_size = 2  # Batch size
    spatial_size = 8  # Spatial dimensions (H=W=8)
    cond_dim = 32  # Conditional dimension
    num_heads = 8  # Number of attention heads

    # Create model instances
    pytorch_attn = SelfAttention(
        channels=channels,
        num_heads=num_heads,
        cond_dim=cond_dim,
        use_attention_ffn=True,
    )
    mlx_attn = SelfAttention_MLX(  # Assuming this is your MLX class name
        channels=channels,
        num_heads=num_heads,
        cond_dim=cond_dim,
        use_attention_ffn=True,
    )

    # Set models to evaluation mode
    pytorch_attn.eval()
    mlx_attn.eval()

    # Create test inputs
    # Main input: [B, C, H, W]
    pytorch_input = torch.randn(batch_size, channels, spatial_size, spatial_size)
    # Conditional input: [B, seq_len, cond_dim]
    cond_seq_len = 4
    pytorch_cond = torch.randn(batch_size, cond_seq_len, cond_dim)
    # Conditional mask: [B, seq_len]
    pytorch_cond_mask = torch.ones(batch_size, cond_seq_len)

    # Test PyTorch version
    pytorch_output = pytorch_attn(
        pytorch_input, cond=pytorch_cond, cond_mask=pytorch_cond_mask
    )

    # Convert inputs to MLX format
    mlx_input = mx.array(pytorch_input.numpy())
    mlx_cond = mx.array(pytorch_cond.numpy())
    mlx_cond_mask = mx.array(pytorch_cond_mask.numpy())

    # Test MLX version
>   mlx_output = mlx_attn.forward(mlx_input, cond=mlx_cond, cond_mask=mlx_cond_mask)

tests/test_mlx_unet.py:111:


ml_mdm/models/unet_mlx.py:126: in forward
    h = self.proj_out(h)


self = Conv2d(64, 64, kernel_size=(1,), stride=(1, 1), padding=(0, 0), dilation=1, groups=1, bias=True)
x = array([[[[0.275203, 0.260433, 0.149972, ..., 0.170037, 0.294273, 0.0233037], [0.238134, 0.305211, 0.304386, ....0327954, 0.0905073], [0.0745733, 0.266315, 0.0139441, ..., -0.124577, -0.181556, 0.0526186]]]], dtype=float32)

    def __call__(self, x):
>       y = mx.conv2d(
            x, self.weight, self.stride, self.padding, self.dilation, self.groups
        )

E ValueError: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,64,8,8) and weight: (64,1,1,64)

../anaconda3/lib/python3.11/site-packages/mlx/nn/layers/convolution.py:157: ValueError

----------------------------- Captured stdout call -----------------------------
input tensor for group norm: (2, 8, 8, 64)
output tensor for groupnorm / input tensor for self attention: (2, 3, 64, 64)
output tensor for self attention: (2, 3, 64, 64)

---------- coverage: platform darwin, python 3.11.5-final-0 ----------
Name                                        Stmts   Miss  Cover
---------------------------------------------------------------
ml_mdm/__init__.py                              1      0   100%
ml_mdm/clis/__init__.py                         0      0   100%
ml_mdm/clis/download_tar_from_index.py        198    198     0%
ml_mdm/clis/generate_batch.py                 139    139     0%
ml_mdm/clis/generate_sample.py                225    225     0%
ml_mdm/clis/run_torchmetrics.py               120    120     0%
ml_mdm/clis/scrape_cc12m.py                    62     62     0%
ml_mdm/clis/train_parallel.py                 157    157     0%
ml_mdm/config.py                              127     85    33%
ml_mdm/diffusion.py                           192    135    30%
ml_mdm/distributed.py                          39     28    28%
ml_mdm/generate_html.py                        18     18     0%
ml_mdm/helpers.py                               8      8     0%
ml_mdm/language_models/__init__.py              0      0   100%
ml_mdm/language_models/factory.py              68     68     0%
ml_mdm/language_models/self_attention.py        5      5     0%
ml_mdm/language_models/tokenizer.py           118    105    11%
ml_mdm/language_models/transformer.py           5      5     0%
ml_mdm/lr_scaler.py                            18     18     0%
ml_mdm/models/__init__.py                       1      0   100%
ml_mdm/models/model_ema.py                     41     31    24%
ml_mdm/models/nested_unet.py                  115     71    38%
ml_mdm/models/unet.py                         507    368    27%
ml_mdm/models/unet_mlx.py                      65      7    89%
ml_mdm/reader.py                              125     89    29%
ml_mdm/s3_helpers.py                           56     43    23%
ml_mdm/samplers.py                            354    276    22%
ml_mdm/trainer.py                              52     52     0%
ml_mdm/utils/__init__.py                        0      0   100%
ml_mdm/utils/fix_old_checkpoints.py            10      7    30%
ml_mdm/utils/simple_logger.py                  85     85     0%
---------------------------------------------------------------

=========================== short test summary info ============================
FAILED tests/test_mlx_unet.py::test_pytorch_mlx_self_attention - ValueError: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,64,8,8) and weight: (64,1,1,64)

bdeanhardt · Feb 21 '25 02:02
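For anyone hitting the same ValueError: MLX's mlx.nn.Conv2d works channels-last (NHWC) and stores weights as (out_channels, kH, kW, in_channels), while PyTorch tensors are channels-first (NCHW). A minimal conversion sketch, assuming mlx and torch are installed — the helper names are hypothetical, not part of ml-mdm:

```python
import numpy as np
import torch
import mlx.core as mx

def torch_nchw_to_mlx_nhwc(t: torch.Tensor) -> mx.array:
    # PyTorch (B, C, H, W) -> MLX (B, H, W, C)
    return mx.array(t.detach().numpy()).transpose(0, 2, 3, 1)

def mlx_nhwc_to_torch_nchw(a: mx.array) -> torch.Tensor:
    # MLX (B, H, W, C) -> PyTorch (B, C, H, W)
    return torch.from_numpy(np.array(a)).permute(0, 3, 1, 2)
```

With the test input transposed this way, the (64, 1, 1, 64) proj_out weight would see 64 input channels in the trailing axis and the 1x1 conv shape check should pass.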

Gonna check it, but the error is probably because of the mlx.array initialization

gabrielfnayres · Feb 21 '25 10:02
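That hypothesis is quick to check: mx.array preserves whatever shape the NumPy buffer has, so an NCHW tensor stays NCHW after conversion and has to be transposed explicitly. A small sanity check (a sketch, assuming mlx is installed):

```python
import numpy as np
import mlx.core as mx

x = mx.array(np.zeros((2, 64, 8, 8), dtype=np.float32))  # NCHW, as in the failing test
print(x.shape)                        # (2, 64, 8, 8) -- layout unchanged by mx.array
print(x.transpose(0, 2, 3, 1).shape)  # (2, 8, 8, 64) -- channels-last, what mlx.nn.Conv2d expects
```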

@bdeanhardt I think I've got this

(screenshot attached: Screenshot 2025-02-21 at 14 59 47)

gabrielfnayres · Feb 21 '25 18:02
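Judging from the captured stdout above (group norm already runs on a (2, 8, 8, 64) tensor, but proj_out then receives (2, 64, 8, 8)), a likely fix is to put the attention output back into channels-last layout before the 1x1 projection. A hedged sketch of that step — the variable names are assumptions, not taken from unet_mlx.py:

```python
# Hypothetical shape of the fix inside SelfAttention_MLX.forward;
# h, B, H, W, C are illustrative names, not the module's actual variables.
h = h.reshape(B, H, W, C)  # back to NHWC before the 1x1 projection
h = self.proj_out(h)       # mlx.nn.Conv2d expects channels-last input
```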

@gabrielfnayres awesome!

bdeanhardt · Feb 21 '25 20:02

> @gabrielfnayres awesome!

Should I open another PR? @luke-carlson

gabrielfnayres · Feb 21 '25 20:02