open_clip
add audio spectrogram transformer, and full audio clip, for building out MuLaN
import torch
from src.open_clip.transformer import AudioSpectrogramTransformer
model = AudioSpectrogramTransformer(
image_size = 256,
patch_size = 16,
width = 512,
heads = 8,
mlp_ratio = 4,
layers = 1,
output_dim = 512
)
wav = torch.randn(1, 1024)
embed = model(wav) # (1, 512)
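For readers new to spectrogram transformers, here is a minimal sketch of the general idea behind the call above. A raw waveform is turned into a magnitude spectrogram with a short-time Fourier transform, and the (freq, time) grid is cut into non-overlapping 16x16 patches that become transformer tokens. This is not the PR's preprocessing. The helper name waveform_to_patches and the n_fft, hop_length and log compression choices are assumptions, picked so that a 1024-sample clip yields a 64x64 grid.

```python
import torch

def waveform_to_patches(wav, n_fft=128, hop_length=16, patch_size=16):
    """Sketch: (batch, samples) waveform -> (batch, num_patches, patch_size**2) tokens."""
    window = torch.hann_window(n_fft)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop_length, window=window,
                      return_complex=True).abs()      # (batch, n_fft // 2 + 1, frames)
    spec = torch.log1p(spec)                           # compress the dynamic range
    b, f, t = spec.shape
    f, t = f - f % patch_size, t - t % patch_size      # crop both axes to multiples of the patch size
    spec = spec[:, :f, :t]
    # split the (freq, time) grid into non-overlapping patch_size x patch_size tiles
    patches = spec.reshape(b, f // patch_size, patch_size, t // patch_size, patch_size)
    return patches.permute(0, 1, 3, 2, 4).reshape(b, -1, patch_size * patch_size)

# a 65x65 spectrogram is cropped to 64x64, giving 4 * 4 = 16 patches of 256 values each
tokens = waveform_to_patches(torch.randn(1, 1024))
```

In a real model, a learned linear layer (not shown) would then project each 256-value patch to the transformer width.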
Now one can do
import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
mulan = AudioCLIP(
embed_dim = 512,
audio_cfg = CLIPAudioCfg(),
text_cfg = CLIPTextCfg()
)
wav = torch.randn(2, 1024)
text = torch.randint(0, 10, (2, 77))
audio_latents, text_latents, _ = mulan(wav, text)
print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)
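The two sets of latents above are what a CLIP-style contrastive objective consumes. Below is a minimal sketch of the symmetric cross-entropy (InfoNCE) loss. It assumes the third return value is an already-exponentiated logit scale, as with open_clip's image-text CLIP. clip_loss is a hypothetical helper, not this PR's training code.

```python
import torch
import torch.nn.functional as F

def clip_loss(audio_latents, text_latents, logit_scale):
    # normalize so that dot products are cosine similarities
    audio = F.normalize(audio_latents, dim=-1)
    text = F.normalize(text_latents, dim=-1)
    logits = logit_scale * audio @ text.t()          # (batch, batch); matching pairs on the diagonal
    labels = torch.arange(logits.shape[0], device=logits.device)
    # average the audio->text and text->audio cross-entropies
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
```

Each audio clip is pushed toward its own caption and away from the other captions in the batch, so larger batches supply more negatives.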
@lucidrains awesome! We should probably put this audio-specific stuff in a new file; I was thinking of splitting the other sub-transformers at some point too... audio_transformer.py?
@rwightman sure: by modality, by functionality, or both. Either way is fine, just let me know
will still need to add the functions for generating the model from a cfg, as well as the full AudioCLIP
perhaps by modality is good
yeah, was thinking modality: leave the base transformer as the parent and split off the modality-specific transformers. At least do it for audio in this case, since it's new; the others can be split later, as other PRs are probably based on the current structure
You got it, will make the changes next week
Have a bunch of meetings with people around the valley this week, I'll get around to finishing this next week
Hi @lucidrains, the current code looks great! Feel free to ping us (Ke and me) when you are finished!
Hi @lucidrains, we have briefly scanned your code and it looks great to us. After you finish the code, just let us know. We will mainly go over the spec-augment (time masking, freq masking, stretching, etc.) and the hyperparameters of the spectrogram transformer. If you can point me to the specific locations in your code, that would be even better.
Thanks!
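SpecAugment-style masking, as mentioned above, zeroes out random bands of frequency bins and random spans of time frames. Below is a minimal sketch. It assumes a (batch, freqs, time) tensor and a mask value of 0. The function name and default mask widths are hypothetical, and time stretching or warping is not covered. torchaudio also ships FrequencyMasking and TimeMasking transforms for the masking part.

```python
import torch

def spec_augment(spec, max_freq_mask=8, max_time_mask=32, num_masks=2, generator=None):
    """Return a masked copy of spec (batch, freqs, time); the input is not modified."""
    spec = spec.clone()
    _, n_freqs, n_time = spec.shape
    for b in range(spec.shape[0]):
        for _ in range(num_masks):
            # frequency mask: zero a band of f consecutive bins (f may be 0)
            f = int(torch.randint(0, min(max_freq_mask, n_freqs) + 1, (1,), generator=generator))
            f0 = int(torch.randint(0, n_freqs - f + 1, (1,), generator=generator))
            spec[b, f0:f0 + f, :] = 0.0
            # time mask: zero a span of t consecutive frames (t may be 0)
            t = int(torch.randint(0, min(max_time_mask, n_time) + 1, (1,), generator=generator))
            t0 = int(torch.randint(0, n_time - t + 1, (1,), generator=generator))
            spec[b, :, t0:t0 + t] = 0.0
    return spec
```

Masks are drawn independently per example, so each item in a batch sees a different corruption.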
@lukewys @RetroCirce Hello Yusong and Ke! Thank you so much for offering your audio expertise; it is more helpful than you realize
The hyperparameters that I am unsure about are listed here to here. But also whatever you think are reasonable default values would be good too!
Also, I decided to keep a lot of the image-related naming in there, in case there is a lot of logic in the library using encode_image or accessing .visual. We are technically treating the audio as a 2D image (time and frequency) anyway.
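One way to read the naming decision above is to keep the audio tower registered under the image-oriented attribute, so that code calling encode_image or reaching into .visual keeps working. Below is a minimal sketch with a hypothetical class name and toy encoders. It is not the PR's AudioCLIP.

```python
import torch
from torch import nn

class AudioCLIPSketch(nn.Module):
    """Hypothetical: the audio tower lives under `visual` so image-oriented code paths still work."""

    def __init__(self, audio_encoder, text_encoder):
        super().__init__()
        self.visual = audio_encoder   # registered under the name existing code expects
        self.text = text_encoder

    @property
    def audio(self):
        # audio-flavoured alias for the same submodule
        return self.visual

    def encode_image(self, spectrogram):
        # the spectrogram is treated as a 2D (freq x time) image
        return self.visual(spectrogram)

    def encode_audio(self, spectrogram):
        return self.encode_image(spectrogram)

model = AudioCLIPSketch(nn.Sequential(nn.Flatten(), nn.Linear(32 * 64, 16)), nn.Embedding(10, 16))
```

Because both names point at one module, there is a single set of weights and no duplicated parameters in the state dict.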
Hi @lucidrains! Can you use a Riffusion spectrogram as input to the encode_image function?
@marianna13 oh hey Marianna! good to hear from you
yes, it should be able to accept spectrograms (you just have to pass in a tensor of shape (batch, freqs, time))
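To make the shape convention concrete: a patch-embedding Conv2d expects a channel axis, so a (batch, freqs, time) spectrogram gets a singleton channel and becomes (batch, 1, freqs, time). That matches the input[2, 1, 32, 1024] shape in the error reported in this thread. Below is a minimal sketch. The 16x16 patch and width of 768 are taken from that error message, and the standalone Conv2d is a stand-in for the model's own patch embedding.

```python
import torch

spectrogram = torch.randn(2, 32, 1024)   # (batch, freqs, time)
x = spectrogram.unsqueeze(1)             # (batch, 1, freqs, time): a single channel for Conv2d

# stand-in patch embedding: kernel 16, stride 16, so the patches do not overlap
patch_embed = torch.nn.Conv2d(in_channels=1, out_channels=768, kernel_size=16, stride=16)
grid = patch_embed(x)                    # (2, 768, 32 // 16, 1024 // 16) = (2, 768, 2, 64)
tokens = grid.flatten(2).transpose(1, 2) # (2, 128, 768): one token per patch
```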
@marianna13 can you make sure the following code can run
import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
mulan = AudioCLIP(
embed_dim = 512,
audio_cfg = CLIPAudioCfg(),
text_cfg = CLIPTextCfg()
)
spectrogram = torch.randn(2, 32, 1024)
text = torch.randint(0, 10, (2, 77))
audio_latents, text_latents, _ = mulan(spectrogram, text)
print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)
@lucidrains no, unfortunately I get this error: RuntimeError: Given groups=1, weight of size [768, 3, 16, 16], expected input[2, 1, 32, 1024] to have 3 channels, but got 1 channels instead
Can you also please tell me if it's possible to run encode_image over a batch of images? From what I found, the input should have 3 dimensions, right? https://github.com/lucidrains/open_clip/blob/audio-compatible/src/open_clip/audio.py#L107
@marianna13 ohh, what is the shape of the input tensor you are passing in? i thought spectrograms only have 1 channel, but i am not really an audio expert
@marianna13 i can make it accommodate 3 channels, if that is the case
@marianna13
import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
mulan = AudioCLIP(
embed_dim = 512,
audio_cfg = CLIPAudioCfg(channels = 3),
text_cfg = CLIPTextCfg(),
)
spectrogram = torch.randn(2, 3, 32, 1024)
text = torch.randint(0, 10, (2, 77))
audio_latents, text_latents, _ = mulan(spectrogram, text)
print(audio_latents.shape) # (2, 512)
print(text_latents.shape) # (2, 512)
hmm, how are you testing this? are you checking out the entire PR? this error may also suggest you don't have the necessary changes to the vision transformer (to be able to configure it to have 1 channel)
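The error above is a plain channel mismatch in the patch-embedding convolution: its weight was built for 3 input channels, while the spectrogram carries 1. Below is a minimal sketch that reproduces it with torch.nn.Conv2d. It uses a hypothetical make_patch_embed helper whose channels argument plays the role of a configurable channel count, like the CLIPAudioCfg(channels = ...) field used earlier in the thread.

```python
import torch
from torch import nn

def make_patch_embed(channels=1, width=768, patch_size=16):
    # hypothetical helper: `channels` is the number of input planes the conv expects
    return nn.Conv2d(channels, width, kernel_size=patch_size, stride=patch_size)

spec = torch.randn(2, 1, 32, 1024)  # one-channel spectrogram batch

try:
    make_patch_embed(channels=3)(spec)  # weight [768, 3, 16, 16] vs a 1-channel input
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(err)  # the same kind of channel-mismatch message as in the thread

out = make_patch_embed(channels=1)(spec)  # works: (2, 768, 2, 64)
```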
@lucidrains I checked again and now it works! (I had forgotten that I'd made changes to the code.) Sorry, my bad!
@marianna13 oh great! can you confirm that you are using 1 channel then? i should revert that commit
@marianna13 i'll add the MulanCoCa version tomorrow too, so we can possibly leapfrog the state of the art going on within Google
@lucidrains yes, I changed back to 1 channel and it worked, but I also tried to run it over a batch of images and it didn't work :(
That's great! Thank you :)
oh, that's odd, what is the shape of the batch of images you are sending in?
@marianna13 if you can show me a reproducible error like the sample script above, i can fix it
Hi @lucidrains ! Sorry for the late reply. Here's the code I'm using:
import torch
import cv2
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
import webdataset as wds
import sys
import os
from torchvision import transforms
from PIL import Image
import numpy as np
import time

transform = transforms.Compose([
    transforms.ToTensor()
])

def preprocess(sample: tuple):
    image, json_data = sample
    # json_data = json.loads(json_data.decode())
    audio_meta = json_data.get('audio_meta', None)
    if audio_meta is not None:
        tags = audio_meta.get('tags', None)
        if tags is not None:
            try:
                title, artist, genre = '', '', ''
                for k in tags.keys():
                    if k in ['title', 'TITLE']:
                        title = f'titled {tags[k]}'
                    if k in ['artist', 'ARTIST']:
                        artist = f'by {tags[k]}'
                    if k in ['genre', 'GENRE']:
                        genre = tags[k]
                label = f'{genre} song "{title}" {artist}'
            except:
                pass
    label = f'{json_data["caption"]}'
    return image, {'label': label}

def get_dataset(urls: list):
    '''
    Pass s3 urls and get processed torch dataset
    '''
    dataset = (
        wds.WebDataset(urls)
        .decode("pil")
        .to_tuple("jpg", "json")
        .map_tuple(transform)
        .map(preprocess)
    )
    return dataset

urls = [f'{i:05}.tar' for i in range(1)]
dataset = get_dataset(urls)
batch_size = 32
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

mulan = AudioCLIP(
    embed_dim = 32,
    audio_cfg = CLIPAudioCfg(**{'image_size': (512, 1001), 'patch_size': 16}),
    text_cfg = CLIPTextCfg()
)

for i, batch in enumerate(loader):
    im, label = batch
    print(type(mulan.encode_image(im)))
One example from the dataset can be found here: https://drive.google.com/file/d/15VFMSovEWCHJcDeg9lXqFnlACJmi5gr5/view?usp=sharing
Thank you!
@marianna13 hey Marianna, thanks for sharing the script
it looks good except for the image dimensions: their height and width need to be divisible by the patch size. That assert should be in the code somewhere, though; maybe it can be left for a separate PR. Beyond that, the image dimensions only matter to the vision transformer for generating the absolute positions, so long as they are the maximum of what you send in during training. The spectrogram must also be of fixed shape during training, for now
Could you try rerunning your script? And could you also insert a print statement before the mulan invocation, in case it fails again even with my recent changes?
for i, batch in enumerate(loader):
    im, label = batch
    print('input shape is:', im.shape)
    print(type(mulan.encode_image(im)))
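The divisibility requirement and the fixed-shape constraint described above can be checked up front. Below is a minimal sketch with hypothetical helpers. check_patchable stands in for the missing assert (raising ValueError instead) and returns the number of patch positions. pad_or_crop_time forces a fixed time length. With patch_size 16, the (512, 1001) size from Marianna's script fails, while 1008, the next multiple of 16, works. Whether to pad or crop is a choice, not something settled in this thread.

```python
import torch
import torch.nn.functional as F

def check_patchable(image_size, patch_size):
    """Raise if either side is not a multiple of patch_size; return the number of patches."""
    h, w = image_size
    if h % patch_size or w % patch_size:
        raise ValueError(f'image size {image_size} must be divisible by patch size {patch_size}')
    # one absolute position embedding is needed per patch
    return (h // patch_size) * (w // patch_size)

def pad_or_crop_time(spec, max_time):
    """Force spec (batch, freqs, time) to exactly max_time frames: right-pad with zeros or crop."""
    t = spec.shape[-1]
    if t >= max_time:
        return spec[..., :max_time]
    return F.pad(spec, (0, max_time - t))
```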
@lucidrains it works! Thank you! :)
@marianna13 hey Marianna, were you able to do a small test run?
if we can even get a training run to overfit on a small training set, maybe we can try to get this PR merged
Hey @lucidrains, I tried to train a model with a small fraction of the dataset but it gets stuck at the first epoch and then gets killed. I can post my training script (I think it might be an issue on my side) but anyway
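A cheap way to run the overfitting check suggested above is to train on a single fixed batch and confirm that the contrastive loss collapses toward zero. If it does not, the model or loss wiring likely has a bug. Below is a minimal sketch that uses toy, hypothetical towers in place of AudioCLIP's encoders: a linear layer over flattened spectrograms and an EmbeddingBag over token ids. The sizes, learning rate and step count are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# toy stand-ins for the audio and text towers (hypothetical)
audio_tower = nn.Sequential(nn.Flatten(), nn.Linear(16 * 32, 64))
text_tower = nn.EmbeddingBag(100, 64)              # mean-pools token embeddings by default
logit_scale = nn.Parameter(torch.tensor(2.659))    # log(1 / 0.07), a common CLIP initialization

spec = torch.randn(8, 16, 32)                      # one fixed small batch of "spectrograms"
text = torch.randint(0, 100, (8, 12))              # matching token ids

params = list(audio_tower.parameters()) + list(text_tower.parameters()) + [logit_scale]
opt = torch.optim.Adam(params, lr=1e-2)

def step():
    a = F.normalize(audio_tower(spec), dim=-1)
    t = F.normalize(text_tower(text), dim=-1)
    logits = logit_scale.exp() * a @ t.t()
    labels = torch.arange(len(spec))
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

losses = [step() for _ in range(200)]  # should fall far below the initial value
```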