
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)

Open · rrscholarship opened this issue 10 months ago · 5 comments

Describe the bug: the model cannot be imported; running the demo code fails at the `import` statement itself.

To Reproduce: import the package (see the minimal snippet below).
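A minimal reproduction, taken directly from the traceback below; the failure happens during the import itself, before any model is constructed:

```python
# Fails at import time with the RuntimeError reported below; no model
# needs to be instantiated for the crash to occur.
import torch
from audio_flamingo.model import AudioFlamingo
```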

Expected behavior: the demo code runs and the model imports without errors.

Screenshots: (attached image; the full traceback is reproduced below)

Additional context

```
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 2
      1 import torch
----> 2 from audio_flamingo.model import AudioFlamingo

File /work/van-speech-nlp/jindaznb/jslpnb/multimodal/AudioFlamingo/audio_flamingo/__init__.py:1
----> 1 from audio_flamingo.model import (
      2     XCAttention,
      3     AudioFlamingoEncoderBlock,
      4     AudioFlamingo,
      5 )
      8 __all__ = [
      9     "XCAttention",
     10     "AudioFlamingoEncoderBlock",
     11     "AudioFlamingo",
     12 ]

File /work/van-speech-nlp/jindaznb/jslpnb/multimodal/AudioFlamingo/audio_flamingo/model.py:8
      6 from torch import einsum, nn
      7 from torch.autograd import Function
----> 8 from zeta.nn import audio_to_text, Attention, SwiGLU
      9 from zeta.structs import Transformer, Decoder, AutoregressiveWrapper
     11 # helper functions

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/__init__.py:28
     25 f = CustomFilter()
     26 logger.addFilter(f)
---> 28 from zeta.nn import *
     29 from zeta.models import *
     30 from zeta.utils import *

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/__init__.py:1
----> 1 from zeta.nn.attention import *
      2 from zeta.nn.embeddings import *
      3 from zeta.nn.modules import *

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/attention/__init__.py:14
     10 from zeta.nn.attention.local_attention_mha import LocalMHA
     12 # from zeta.nn.attention.mgqa import MGQA
     13 # from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
---> 14 from zeta.nn.attention.mixture_attention import (
     15     MixtureOfAttention,
     16     MixtureOfAutoregressiveAttention,
     17 )
     18 from zeta.nn.attention.multi_modal_causal_attention import (
     19     MultiModalCausalAttention,
     20     SimpleMMCA,
     21 )
     22 from zeta.nn.attention.multihead_attention import MultiheadAttention

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/attention/mixture_attention.py:8
      6 from typing import Tuple, Optional
      7 from einops import rearrange, repeat, reduce
----> 8 from zeta.models.vit import exists
      9 from zeta.structs.transformer import RMSNorm, apply_rotary_pos_emb
     11 from zeta.nn.attention.attend import Attend

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/models/__init__.py:3
      1 # Copyright (c) 2022 Agora
      2 # Licensed under The MIT License [see LICENSE for details]
----> 3 from zeta.models.andromeda import Andromeda
      4 from zeta.models.base import BaseModel
      5 from zeta.models.gpt4 import GPT4, GPT4MultiModal

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/models/andromeda.py:4
      1 # the best llm ever made
      2 from torch.nn import Module
----> 4 from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
      5 from zeta.structs.transformer import (
      6     Decoder,
      7     Transformer,
      8 )
     11 class Andromeda(Module):

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/structs/__init__.py:4
      2 from zeta.structs.encoder_decoder import EncoderDecoder
      3 from zeta.structs.hierarchical_transformer import HierarchicalTransformer
----> 4 from zeta.structs.local_transformer import LocalTransformer
      5 from zeta.structs.parallel_transformer import ParallelTransformerBlock
      6 from zeta.structs.transformer import (
      7     Decoder,
      8     Encoder,
      9     Transformer,
     10     ViTransformerWrapper,
     11 )

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/structs/local_transformer.py:8
      6 from zeta.nn.attention.local_attention_mha import LocalMHA
      7 from zeta.nn.biases.dynamic_position_bias import DynamicPositionBias
----> 8 from zeta.nn.modules import feedforward_network
      9 from zeta.utils.main import eval_decorator, exists, top_k
     12 class LocalTransformer(nn.Module):

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/__init__.py:47
     45 from zeta.nn.modules.s4 import s4d_kernel
     46 from zeta.nn.modules.h3 import H3Layer
---> 47 from zeta.nn.modules.mlp_mixer import MLPMixer
     48 from zeta.nn.modules.leaky_relu import LeakyRELU
     49 from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:145
    141 # Example input tensor
    142 example_input = torch.randn(
    143     1, 512, 32, 32
    144 )  # Batch size of 1, 512 channels, 32x32 image
--> 145 output = mlp_mixer(example_input)
    146 print(
    147     output.shape
    148 )

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:125, in MLPMixer.forward(self, x)
    123 x = rearrange(x, "n c h w -> n (h w) c")
    124 for mixer_block in self.mixer_blocks:
--> 125     x = mixer_block(x)
    126 x = self.pred_head_layernorm(x)
    127 x = x.mean(dim=1)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:63, in MixerBlock.forward(self, x)
     61 y = self.norm1(x)
     62 y = rearrange(y, "n c t -> n t c")
---> 63 y = self.tokens_mlp(y)
     64 y = rearrange(y, "n t c -> n c t")
     65 x = x + y

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:30, in MLPBlock.forward(self, x)
     21 def forward(self, x: torch.Tensor) -> torch.Tensor:
     22     """Forward pass of MLPBlock
     23
     24     Args:
    (...)
     28         torch.Tensor: description
     29     """
---> 30     y = self.dense1(x)
     31     y = F.gelu(y)
     32     return self.dense2(y)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
```
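My reading of the traceback (an inference on my part, not confirmed by the maintainers): the installed zeta release executes demo code at module level in zeta/nn/modules/mlp_mixer.py, so lines 141-148 run an MLPMixer forward pass as a side effect of `import zeta`, and that demo itself hits the shape mismatch. In PyTorch's error message, mat1 is the input flattened to 2D and mat2 is the Linear layer's transposed weight, which suggests the tensor reaching `MLPBlock.dense1` has a trailing dimension of 4 while the layer expects in_features = 512. Since AudioFlamingo does `from zeta.nn import ...`, the crash fires before any user code runs.

A possible local workaround, sketched under the assumption that the demo lines in the installed mlp_mixer.py match what the traceback shows, is to guard them so they no longer execute on import (the demo body below is paraphrased from the traceback, not copied from the file):

```python
# In site-packages/zeta/nn/modules/mlp_mixer.py: wrap the module-level
# demo in a __main__ guard so `import zeta` no longer runs it.
if __name__ == "__main__":
    # Example input tensor: batch of 1, 512 channels, 32x32 image
    example_input = torch.randn(1, 512, 32, 32)
    output = mlp_mixer(example_input)  # `mlp_mixer` is built earlier in the same demo section
    print(output.shape)
```

Alternatively, pinning a zeta release that does not run this demo at import time (if one exists) should avoid the error without patching site-packages.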


rrscholarship · Apr 11 '24 06:04