
add audio spectrogram transformer, and full audio clip

lucidrains opened this issue • 40 comments

for building out MuLaN

import torch
from src.open_clip.transformer import AudioSpectrogramTransformer

model = AudioSpectrogramTransformer(
    image_size = 256,
    patch_size = 16,
    width = 512,
    heads = 8,
    mlp_ratio = 4,
    layers = 1,
    output_dim = 512
)

wav = torch.randn(1, 1024)

embed = model(wav) # (1, 512)

Now one can do


import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(),
    text_cfg = CLIPTextCfg()
)

wav = torch.randn(2, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _  = mulan(wav, text)

print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)

lucidrains avatar Feb 03 '23 20:02 lucidrains

@lucidrains awesome, we should probably put this audio specific stuff in a new file, was thinking of splitting the other sub-transformers at some point too ... audio_transformer.py ?

rwightman avatar Feb 04 '23 22:02 rwightman

@rwightman sure, by modality, or by functionality, or both, either way is fine just let me know

lucidrains avatar Feb 04 '23 23:02 lucidrains

will still need to add the functions for generating from cfg as well as the full AudioClip

perhaps by modality is good

lucidrains avatar Feb 04 '23 23:02 lucidrains
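
The "generating from cfg" piece would presumably mirror open_clip's existing _build_vision_tower / _build_text_tower pattern. A hypothetical sketch of such a builder (the dataclass fields and the _build_audio_tower name are illustrative and may not match what lands in the PR):

from dataclasses import dataclass

from src.open_clip.transformer import AudioSpectrogramTransformer

@dataclass
class CLIPAudioCfg:
    # illustrative fields only; the actual dataclass in the PR may differ
    image_size: int = 256
    patch_size: int = 16
    width: int = 512
    heads: int = 8
    mlp_ratio: float = 4.0
    layers: int = 12

def _build_audio_tower(embed_dim: int, audio_cfg: CLIPAudioCfg):
    # analogous to _build_vision_tower: turn a config dataclass into the module it describes
    return AudioSpectrogramTransformer(
        image_size = audio_cfg.image_size,
        patch_size = audio_cfg.patch_size,
        width = audio_cfg.width,
        heads = audio_cfg.heads,
        mlp_ratio = audio_cfg.mlp_ratio,
        layers = audio_cfg.layers,
        output_dim = embed_dim,
    )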

yeah, was thinking modality: leave the base transformer as the parent and split off modality-specific transformers, at least in this case audio since it's new. Can split the others later, as other PRs are probably based on the current structure

rwightman avatar Feb 04 '23 23:02 rwightman

You got it, will make the changes next week

lucidrains avatar Feb 04 '23 23:02 lucidrains

Have a bunch of meetings with people around the valley this week, I'll get around to finishing this next week

lucidrains avatar Feb 06 '23 20:02 lucidrains

Hi @lucidrains the current code looks great! Feel free to ping us (Ke and I) when you are finished!

lukewys avatar Feb 06 '23 22:02 lukewys

Hi @lucidrains, we briefly scanned your code and it looks great to us. Once you finish the code, just let us know. We will mainly go over the spec-augment (time masking, frequency masking, time stretching, etc.) and the hyperparameters of the spectrogram transformer. If you can point us to the specific locations in your code, that would be even better.

Thanks!

RetroCirce avatar Feb 06 '23 22:02 RetroCirce
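
For reference, SpecAugment-style time and frequency masking amounts to zeroing out random bands of the spectrogram. A minimal sketch, assuming a (batch, freqs, time) input; the function name and mask widths are illustrative and not part of the PR:

import torch

def spec_augment(spec, time_mask_width = 32, freq_mask_width = 8, num_masks = 2):
    # spec: (batch, freqs, time); augmentation is applied to a copy
    spec = spec.clone()
    batch, n_freqs, n_time = spec.shape
    for _ in range(num_masks):
        # frequency masking: zero a random band of consecutive frequency bins
        f = int(torch.randint(0, freq_mask_width + 1, (1,)))
        f0 = int(torch.randint(0, max(1, n_freqs - f), (1,)))
        spec[:, f0:f0 + f, :] = 0.
        # time masking: zero a random span of consecutive frames
        t = int(torch.randint(0, time_mask_width + 1, (1,)))
        t0 = int(torch.randint(0, max(1, n_time - t), (1,)))
        spec[:, :, t0:t0 + t] = 0.
    return spec

spec = torch.randn(2, 32, 1024)
augmented = spec_augment(spec) # same shape, with random time/freq bands zeroed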

@lukewys @RetroCirce Hello Yusong and Ke! Thank you so much for offering your audio expertise; it is more helpful than you realize

The hyperparameters that I am unsure about are listed from here to here. But whatever you think are reasonable default values would be good too!

lucidrains avatar Feb 07 '23 17:02 lucidrains

also, I decided to keep a lot of the image naming in there, in case there is logic elsewhere in the library that uses encode_image or accesses .visual. We are technically treating the audio as a 2D image (time and frequency) anyway.

lucidrains avatar Feb 18 '23 02:02 lucidrains
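
To illustrate the "spectrogram as 2D image" framing: a ViT-style patch embedding simply treats (freqs, time) as (height, width) with a single input channel. A rough sketch of the idea, not the PR's actual patch-embedding code:

import torch
from torch import nn

patch_size = 16
# a Conv2d with stride equal to its kernel size acts as the patch embedding
patch_embed = nn.Conv2d(1, 512, kernel_size = patch_size, stride = patch_size)

spec = torch.randn(2, 32, 1024)              # (batch, freqs, time)
patches = patch_embed(spec.unsqueeze(1))     # (2, 512, 32 // 16, 1024 // 16)
tokens = patches.flatten(2).transpose(1, 2)  # (2, 128, 512), one token per patch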

Hi @lucidrains! Can you use a riffusion spectrogram as input to the encode_image function?

marianna13 avatar Mar 04 '23 07:03 marianna13

@marianna13 oh hey Marianna! good to hear from you

yes, it should be able to accept spectrograms (you just have to pass in a tensor of shape batch, freqs, time)

lucidrains avatar Mar 04 '23 16:03 lucidrains
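
For anyone following along, one way to produce such a (batch, freqs, time) tensor from raw audio is torchaudio's mel spectrogram transform. The parameters below are illustrative, not necessarily what riffusion or this PR uses:

import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate = 16000,
    n_fft = 1024,
    hop_length = 256,
    n_mels = 64,
)

wav = torch.randn(2, 16000)    # (batch, samples), one second at 16 kHz
spec = mel(wav)                # (batch, n_mels, frames) == (batch, freqs, time)
spec = torch.log(spec + 1e-6)  # log-compress, as is common before a transformer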

@marianna13 can you make sure the following code can run

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(),
    text_cfg = CLIPTextCfg()
)

spectrogram = torch.randn(2, 32, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _  = mulan(spectrogram, text)

print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)

lucidrains avatar Mar 04 '23 17:03 lucidrains

@lucidrains no, unfortunately I get this error:

RuntimeError: Given groups=1, weight of size [768, 3, 16, 16], expected input[2, 1, 32, 1024] to have 3 channels, but got 1 channels instead

Can you also please tell me whether it's possible to run encode_image over a batch of images? I found that the input should have 3 dimensions, is that right? https://github.com/lucidrains/open_clip/blob/audio-compatible/src/open_clip/audio.py#L107

marianna13 avatar Mar 04 '23 18:03 marianna13

@marianna13 ohh, what is the shape of the input tensor you are passing in? i thought spectrograms only have 1 channel, but i am not really an audio expert

lucidrains avatar Mar 04 '23 19:03 lucidrains

@marianna13 i can make it accommodate 3 channels, if that is the case

lucidrains avatar Mar 04 '23 19:03 lucidrains
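
Until that change lands, one stopgap (just a sketch of the workaround, not the fix that went into the branch) is to repeat a single-channel spectrogram across the channel dimension so it matches a 3-channel patch embedding:

import torch

spec = torch.randn(2, 32, 1024)                   # (batch, freqs, time), implicitly 1 channel
spec_3ch = spec.unsqueeze(1).repeat(1, 3, 1, 1)   # (2, 3, 32, 1024), matches a 3-channel conv stem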

@marianna13

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(channels = 3),
    text_cfg = CLIPTextCfg(),
)

spectrogram = torch.randn(2, 3, 32, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _  = mulan(spectrogram, text)

print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)

lucidrains avatar Mar 04 '23 19:03 lucidrains

@lucidrains no, unfortunately I get this error:

RuntimeError: Given groups=1, weight of size [768, 3, 16, 16], expected input[2, 1, 32, 1024] to have 3 channels, but got 1 channels instead

Can you also please tell me whether it's possible to run encode_image over a batch of images? I found that the input should have 3 dimensions, is that right? https://github.com/lucidrains/open_clip/blob/audio-compatible/src/open_clip/audio.py#L107

hmm, how are you testing this? are you checking out the entire PR? this error may also suggest you don't have the necessary changes to the vision transformer (to be able to configure it to have 1 channel)

lucidrains avatar Mar 04 '23 19:03 lucidrains

@lucidrains I checked again and now it works! (I had just forgotten that I'd made changes to the code.) Sorry, that's my bad!

marianna13 avatar Mar 04 '23 19:03 marianna13

@marianna13 oh great! can you confirm that you are using 1 channel then? i should revert that commit

lucidrains avatar Mar 04 '23 19:03 lucidrains

@marianna13 i'll add the MulanCoCa version tomorrow too, so we can possibly leapfrog the state of the art going on within google

lucidrains avatar Mar 04 '23 19:03 lucidrains

@lucidrains yes, I changed back to 1 channel and it worked, but when I tried to run it over a batch of images it didn't work :(

marianna13 avatar Mar 04 '23 20:03 marianna13

@marianna13 i'll add the MulanCoCa version tomorrow too, so we can possibly leapfrog the state of the art going on within google

That's great! Thank you :)

marianna13 avatar Mar 04 '23 20:03 marianna13

@lucidrains yes, I changed back to 1 channel and it worked, but when I tried to run it over a batch of images it didn't work :(

oh, that's odd, what is the shape of the batch of images you are sending in?

lucidrains avatar Mar 04 '23 20:03 lucidrains

@marianna13 if you can show me a reproducible error like the sample script above, i can fix it

lucidrains avatar Mar 04 '23 20:03 lucidrains

Hi @lucidrains ! Sorry for the late reply. Here's the code I'm using:

import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
import webdataset as wds
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor()
])


def preprocess(sample:tuple):
    image, json_data = sample
    # json_data = json.loads(json_data.decode())
   
    audio_meta = json_data.get('audio_meta', None)
    
    if audio_meta is not None:
        tags = audio_meta.get('tags', None)
        if tags is not None:
            try:
                title, artist, genre = '', '', ''
                for k in tags.keys():
                    if k in ['title', 'TITLE']:
                        title = f'titled {tags[k]}'
                    if k in ['artist', 'ARTIST']:
                        artist = f'by {tags[k]}'
                    if k in ['genre', 'GENRE']:
                        genre = tags[k]

                label = f'{genre} song "{title}" {artist}'
            except:
                pass
    label = f'{json_data["caption"]}'
    

    return image, {'label': label}

def get_dataset(urls: list):
    '''
    Pass s3 urls and get processed torch dataset
    '''
    dataset = (
           wds.WebDataset(urls)
           .decode("pil")
           .to_tuple("jpg", "json")
           .map_tuple(transform)
           .map(preprocess)
    )
    return dataset

urls = [f'{i:05}.tar' for i in range(1)]
dataset = get_dataset(urls)

batch_size = 32

loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)


mulan = AudioCLIP(
    embed_dim = 32,
    audio_cfg = CLIPAudioCfg(**{'image_size': (512, 1001), 'patch_size': 16}),
    text_cfg = CLIPTextCfg()
)


for i, batch in enumerate(loader):
    im, label = batch
    print(type(mulan.encode_image(im)))

One example shard of the dataset can be found here: https://drive.google.com/file/d/15VFMSovEWCHJcDeg9lXqFnlACJmi5gr5/view?usp=sharing

Thank you!

marianna13 avatar Mar 05 '23 09:03 marianna13

@marianna13 hey Marianna, thanks for sharing the script

It looks good except for the image dimensions, whose height and width need to be divisible by the patch size. That assert should probably be in the code somewhere, but maybe that is left for a separate PR. The dimensions also do not matter much for the vision transformer other than for generating the absolute positions, as long as the configured image dimensions are the maximum of what you send in during training. The spectrogram must be of fixed shape during training as well, for now.

Could you try rerunning your script? Also insert a print statement before the mulan invocation, in case it fails again even with my recent changes:

for i, batch in enumerate(loader):
    im, label = batch
    print('input shape is:', im.shape)
    print(type(mulan.encode_image(im)))

lucidrains avatar Mar 05 '23 17:03 lucidrains
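
Until such an assert exists, one way to satisfy the divisibility constraint is to pad (or crop) the spectrogram to the nearest multiple of the patch size before calling the model. The helper below is hypothetical, not part of the PR:

import torch
import torch.nn.functional as F

def pad_to_multiple(spec, patch_size = 16):
    # spec: (batch, channels, height, width); zero-pad on the right/bottom so both dims divide evenly
    h, w = spec.shape[-2:]
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    return F.pad(spec, (0, pad_w, 0, pad_h))

spec = torch.randn(2, 1, 512, 1001)
padded = pad_to_multiple(spec)   # (2, 1, 512, 1008), 1008 is divisible by 16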

@lucidrains it works! Thank you! :)

marianna13 avatar Mar 06 '23 06:03 marianna13

@marianna13 hey Marianna, were you able to do a small test run?

if we can even get a training run to overfit on a small training set, maybe we can try to get this PR merged

lucidrains avatar Mar 18 '23 15:03 lucidrains
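
For that overfitting sanity check, a minimal loop along these lines should be enough. This is only a sketch: it assumes the third value returned by the model is the exponentiated logit scale (as in open_clip's CLIP forward), and the symmetric contrastive loss and optimizer settings are illustrative rather than the actual open_clip training code:

import torch
import torch.nn.functional as F
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(embed_dim = 512, audio_cfg = CLIPAudioCfg(), text_cfg = CLIPTextCfg())
optimizer = torch.optim.AdamW(mulan.parameters(), lr = 3e-4)

# a tiny fixed batch to memorize
wav = torch.randn(8, 1024)
text = torch.randint(0, 10, (8, 77))

for step in range(100):
    audio_latents, text_latents, logit_scale = mulan(wav, text)
    logits = logit_scale * audio_latents @ text_latents.t()
    labels = torch.arange(logits.shape[0])
    # symmetric InfoNCE over audio->text and text->audio
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # loss should fall toward zero if the model can overfit the batch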

Hey @lucidrains, I tried to train a model on a small fraction of the dataset, but it gets stuck at the first epoch and then the process gets killed. I can post my training script, though I think it might be an issue on my side anyway.

marianna13 avatar Mar 18 '23 17:03 marianna13