AudioFlamingo
RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
**Describe the bug**
Cannot import the model; the import itself fails while running the demo code.

**To Reproduce**
Import the package, as in the sketch below.
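This is the exact cell from the traceback below; the error is raised by the import itself, before any model is constructed:

```python
import torch

# Importing AudioFlamingo triggers `import zeta`, which crashes
# while executing module-level code (see the traceback below).
from audio_flamingo.model import AudioFlamingo
```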
**Expected behavior**
The demo runs without errors.
**Screenshots**
N/A

**Additional context**
Full traceback:
```
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 2
      1 import torch
----> 2 from audio_flamingo.model import AudioFlamingo

File /work/van-speech-nlp/jindaznb/jslpnb/multimodal/AudioFlamingo/audio_flamingo/__init__.py:1
----> 1 from audio_flamingo.model import (
      2     XCAttention,
      3     AudioFlamingoEncoderBlock,
      4     AudioFlamingo,
      5 )
      8 __all__ = [
      9     "XCAttention",
     10     "AudioFlamingoEncoderBlock",
     11     "AudioFlamingo",
     12 ]

File /work/van-speech-nlp/jindaznb/jslpnb/multimodal/AudioFlamingo/audio_flamingo/model.py:8
      6 from torch import einsum, nn
      7 from torch.autograd import Function
----> 8 from zeta.nn import audio_to_text, Attention, SwiGLU
      9 from zeta.structs import Transformer, Decoder, AutoregressiveWrapper
     11 # helper functions

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/__init__.py:28
     25 f = CustomFilter()
     26 logger.addFilter(f)
---> 28 from zeta.nn import *
     29 from zeta.models import *
     30 from zeta.utils import *

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/__init__.py:1
----> 1 from zeta.nn.attention import *
      2 from zeta.nn.embeddings import *
      3 from zeta.nn.modules import *

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/attention/__init__.py:14
     10 from zeta.nn.attention.local_attention_mha import LocalMHA
     12 # from zeta.nn.attention.mgqa import MGQA
     13 # from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
---> 14 from zeta.nn.attention.mixture_attention import (
     15     MixtureOfAttention,
     16     MixtureOfAutoregressiveAttention,
     17 )
     18 from zeta.nn.attention.multi_modal_causal_attention import (
     19     MultiModalCausalAttention,
     20     SimpleMMCA,
     21 )
     22 from zeta.nn.attention.multihead_attention import MultiheadAttention

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/attention/mixture_attention.py:8
      6 from typing import Tuple, Optional
      7 from einops import rearrange, repeat, reduce
----> 8 from zeta.models.vit import exists
      9 from zeta.structs.transformer import RMSNorm, apply_rotary_pos_emb
     11 from zeta.nn.attention.attend import Attend

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/models/__init__.py:3
      1 # Copyright (c) 2022 Agora
      2 # Licensed under The MIT License [see LICENSE for details]
----> 3 from zeta.models.andromeda import Andromeda
      4 from zeta.models.base import BaseModel
      5 from zeta.models.gpt4 import GPT4, GPT4MultiModal

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/models/andromeda.py:4
      1 # the best llm ever made
      2 from torch.nn import Module
----> 4 from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
      5 from zeta.structs.transformer import (
      6     Decoder,
      7     Transformer,
      8 )
     11 class Andromeda(Module):

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/structs/__init__.py:4
      2 from zeta.structs.encoder_decoder import EncoderDecoder
      3 from zeta.structs.hierarchical_transformer import HierarchicalTransformer
----> 4 from zeta.structs.local_transformer import LocalTransformer
      5 from zeta.structs.parallel_transformer import ParallelTransformerBlock
      6 from zeta.structs.transformer import (
      7     Decoder,
      8     Encoder,
      9     Transformer,
     10     ViTransformerWrapper,
     11 )

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/structs/local_transformer.py:8
      6 from zeta.nn.attention.local_attention_mha import LocalMHA
      7 from zeta.nn.biases.dynamic_position_bias import DynamicPositionBias
----> 8 from zeta.nn.modules import feedforward_network
      9 from zeta.utils.main import eval_decorator, exists, top_k
     12 class LocalTransformer(nn.Module):

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/__init__.py:47
     45 from zeta.nn.modules.s4 import s4d_kernel
     46 from zeta.nn.modules.h3 import H3Layer
---> 47 from zeta.nn.modules.mlp_mixer import MLPMixer
     48 from zeta.nn.modules.leaky_relu import LeakyRELU
     49 from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:145
    141 # Example input tensor
    142 example_input = torch.randn(
    143     1, 512, 32, 32
    144 )  # Batch size of 1, 512 channels, 32x32 image
--> 145 output = mlp_mixer(example_input)
    146 print(
    147     output.shape
    148 )

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:125, in MLPMixer.forward(self, x)
    123 x = rearrange(x, "n c h w -> n (h w) c")
    124 for mixer_block in self.mixer_blocks:
--> 125     x = mixer_block(x)
    126 x = self.pred_head_layernorm(x)
    127 x = x.mean(dim=1)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:63, in MixerBlock.forward(self, x)
     61 y = self.norm1(x)
     62 y = rearrange(y, "n c t -> n t c")
---> 63 y = self.tokens_mlp(y)
     64 y = rearrange(y, "n t c -> n c t")
     65 x = x + y

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/zeta/nn/modules/mlp_mixer.py:30, in MLPBlock.forward(self, x)
     21 def forward(self, x: torch.Tensor) -> torch.Tensor:
     22     """Forward pass of MLPBlock
     23
     24     Args:
   (...)
     28         torch.Tensor: description
     29     """
---> 30     y = self.dense1(x)
     31     y = F.gelu(y)
     32     return self.dense2(y)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /work/van-speech-nlp/jindaznb/asrenv/lib/python3.10/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
```
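The traceback suggests the root cause is not in AudioFlamingo itself: `zeta/nn/modules/mlp_mixer.py` runs a demo forward pass at module scope (line 145 above), so `import zeta` executes that demo and crashes on its shape bug. As a self-contained stand-in that uses only the shapes from the error message (the real `MLPMixer` constructor arguments are not visible in the traceback), the final multiplication fails like this:

```python
import torch
import torch.nn as nn

# Stand-in for the failing token-mixing layer: a Linear expecting
# in_features=512 receives a tensor whose last dimension is 4.
dense = nn.Linear(512, 512)   # weight is 512x512 (mat2 in the error)
y = torch.randn(512, 4)       # mat1 in the error

try:
    dense(y)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (512x4 and 512x512)
```

Independent of the shape bug, a conventional fix on the zeta side would be to keep demo code from executing on import. The sketch below only illustrates that guard pattern with a placeholder model, since the actual `MLPMixer` arguments are not shown here; it is not a patch against the real file:

```python
import torch
import torch.nn as nn


class StandInMixer(nn.Module):
    """Placeholder for MLPMixer (hypothetical; the real constructor args are unknown)."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(512, 512)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(2).transpose(1, 2)  # n c h w -> n (h w) c
        return self.proj(x)


# With this guard, the demo runs only when the file is executed directly,
# so `import zeta` no longer trips over it.
if __name__ == "__main__":
    example_input = torch.randn(1, 512, 32, 32)  # batch 1, 512 channels, 32x32
    print(StandInMixer()(example_input).shape)   # torch.Size([1, 1024, 512])
```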
Upvote & Fund
- We're using Polar.sh so you can upvote and help fund this issue.
- We receive the funding once the issue is completed & confirmed by you.
- Thank you in advance for helping prioritize & fund our backlog.