audio-diffusion-pytorch
Could you provide an example recipe?
- Based on an open source dataset
- Detailed training parameters
Is the following example right?
import librosa
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm
from loguru import logger
from audio_diffusion_pytorch import DiffusionModel
from audio_diffusion_pytorch import UNetV0
from audio_diffusion_pytorch import VDiffusion
from audio_diffusion_pytorch import VSampler

LEN = 2 ** 18


class AudioDataset(Dataset):
    def __init__(self, fpath):
        self.file_list = []
        for line in open(fpath):
            line = line.strip()
            if not line:
                continue
            self.file_list.append(line)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_file = self.file_list[idx]
        audio, fs = torchaudio.load(audio_file)
        transform = torchaudio.transforms.Resample(fs, 48000)
        audio = transform(audio)
        if audio.shape[1] > LEN:
            offset = np.random.randint(0, audio.shape[1] - LEN)
        else:
            offset = 0
        return audio[:, offset:offset + LEN]


def collate_fn(batch):
    bsz = len(batch)
    out = torch.zeros(bsz, 2, LEN)
    for i, x in enumerate(batch):
        out[i, :, :x.shape[1]] = x  # torch.from_numpy(x)
    return out


model = DiffusionModel(
    net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
    in_channels=2,  # U-Net: number of input/output (audio) channels
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],  # U-Net: channels at each layer
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4],  # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
    attention_heads=8,  # U-Net: number of attention heads per attention item
    attention_features=64,  # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion,  # The diffusion method used
    sampler_t=VSampler,  # The diffusion sampler used
)
model.to('cuda:0')

train_dataset = AudioDataset('data/train.list')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=8,
    pin_memory=True
)

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""))

for i in range(6):
    for audio in tqdm(train_dataloader, desc='epoch %d' % i):
        audio = audio.to('cuda:0')
        loss = model(audio)
        logger.info('loss = %f' % loss.item())
        loss.backward()
    torch.save(model.state_dict(), 'model_%d.pt' % i)
Looks like it's missing the optimizer; I'd suggest following a basic PyTorch tutorial on how to set up the training loop.
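For illustration, here is a minimal sketch of how an optimizer could be wired into that loop (Adam with a 1e-4 learning rate is an assumption, not a recommendation from the library authors):

```python
import torch

# Assumes `model` and `train_dataloader` are defined as in the snippet above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed optimizer and learning rate

for epoch in range(6):
    for audio in train_dataloader:
        audio = audio.to('cuda:0')
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss = model(audio)    # DiffusionModel returns the diffusion training loss
        loss.backward()        # backpropagate
        optimizer.step()       # update the weights
    torch.save(model.state_dict(), 'model_%d.pt' % epoch)
```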
@flavioschneider, can you please provide an example script showing how to train the model, and also how to get a dataset to train it on?
@gandolfxu Did you get the correct script to train the model? If so, please help me out with that, and also let me know which dataset you are using.
@flavioschneider Please let me know what kind of dataset can be used to train the model and how it should be structured.
Hi @deepak-newzera, I've written a simple training script here: https://github.com/jameshball/audio-diffusion/blob/master/train.py
It uses the LibriSpeech dataset and downloads it when you start the script.
You might need to change the data path defined at the top, and set up or remove the Weights & Biases (wandb) logging.
I'm currently having an issue with training where I get NaN losses (https://github.com/archinetai/audio-diffusion-pytorch/issues/52), but at least this code should give you something to start with.
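For anyone hitting the same NaN problem, one common mitigation under mixed precision (just a sketch of the standard PyTorch AMP pattern, not a confirmed fix for that specific issue) is to clip gradients between the backward pass and the optimizer step; `scaler`, `optimizer`, and `loss` here refer to the variables in the training script:

```python
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale gradients so the clipping threshold is in real units
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # max norm of 1.0 is an assumed value
scaler.step(optimizer)  # GradScaler skips the step if gradients contain inf/NaN
scaler.update()
```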
@jameshball Thanks for your help. The dataset you linked seems to be a speech dataset, but shouldn't it be a music dataset?
Hi @jameshball, it seems the training script you linked now returns a 404. Do you have an updated link to this training loop somewhere you could share?
I've made that repo private now but this is the version of the file I linked:
import torch
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
from audio_data_pytorch import LibriSpeechDataset, AllTransform

SAMPLE_RATE = 16000
BATCH_SIZE = 12
NUM_SAMPLES = 2**18


def create_model():
    return DiffusionModel(
        net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
        in_channels=1,  # U-Net: number of input/output (audio) channels
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],  # U-Net: channels at each layer
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4],  # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
        attention_heads=8,  # U-Net: number of attention heads per attention item
        attention_features=64,  # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion,  # The diffusion method used
        sampler_t=VSampler,  # The diffusion sampler used
    )


def main():
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataset = LibriSpeechDataset(
        root="E:/librispeech",
        transforms=AllTransform(
            random_crop_size=NUM_SAMPLES,
            mono=True,
        ),
    )
    print(f"Dataset length: {len(dataset)}")
    torchaudio.save("test.wav", dataset[0], SAMPLE_RATE)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    model = create_model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    run_id = wandb.util.generate_id()
    if args.run_id is not None:
        run_id = args.run_id
    print(f"Run ID: {run_id}")

    wandb.init(project="audio-diffusion", resume=args.resume, id=run_id)

    epoch = 0
    step = 0

    if args.checkpoint is not None:
        checkpoint_path = args.checkpoint
    else:
        checkpoint_path = f"checkpoint-{run_id}.pt"

    if wandb.run.resumed:
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(wandb.restore(checkpoint_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        step = epoch * len(dataloader)

    scaler = torch.cuda.amp.GradScaler()

    model.train()
    while epoch < 100:
        avg_loss = 0
        avg_loss_step = 0
        progress = tqdm(dataloader)
        for i, audio in enumerate(progress):
            optimizer.zero_grad()
            audio = audio.to(device)
            with torch.cuda.amp.autocast():
                loss = model(audio)
            avg_loss += loss.item()
            avg_loss_step += 1
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            progress.set_postfix(
                loss=loss.item(),
                epoch=epoch + i / len(dataloader),
            )

            if step % 500 == 0:
                # Turn noise into new audio sample with diffusion
                noise = torch.randn(1, 1, NUM_SAMPLES, device=device)
                with torch.cuda.amp.autocast():
                    sample = model.sample(noise, num_steps=100)
                torchaudio.save(f'test_generated_sound_{step}.wav', sample[0].cpu(), SAMPLE_RATE)
                del sample
                gc.collect()
                torch.cuda.empty_cache()
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                    "generated_audio": wandb.Audio(f'test_generated_sound_{step}.wav', caption="Generated audio", sample_rate=SAMPLE_RATE),
                })

            if step % 100 == 0:
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                })
                avg_loss = 0
                avg_loss_step = 0

            step += 1

        epoch += 1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        wandb.save(checkpoint_path)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_id", type=str, default=None)
    return parser.parse_args()


if __name__ == "__main__":
    main()
It has some hardcoded paths and wandb code that you might need to remove, but it worked nicely.
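If you don't have a wandb account, one option (an assumed tweak, not something the original script does) is to initialize wandb in offline mode so metrics are logged locally instead of uploaded:

```python
wandb.init(project="audio-diffusion", id=run_id, mode="offline")  # logs locally, no account needed
```

Note that resuming from a wandb-stored checkpoint via `wandb.restore` would not apply offline, but loading a local checkpoint file still works.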
@jameshball Did you succeed in training the model? Is it producing sensible outputs? I would also like to train using your script, but on a set of wav files instead of LibriSpeechDataset. If possible, could you explain how that can be done?
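In case it helps with the question above, here is a minimal sketch of a directory-based dataset, adapted from the AudioDataset posted earlier in this thread; the class name, folder path, and target sample rate are assumptions, not part of either script:

```python
import glob
import os

import torch
import torchaudio
from torch.utils.data import Dataset

NUM_SAMPLES = 2 ** 18  # same crop length used in the training script above


class WavFolderDataset(Dataset):
    """Loads .wav files from a folder, resamples to a target rate, and returns
    a fixed-length mono crop (matching in_channels=1 in the model above)."""

    def __init__(self, root, sample_rate=16000, num_samples=NUM_SAMPLES):
        self.files = sorted(glob.glob(os.path.join(root, "**", "*.wav"), recursive=True))
        self.sample_rate = sample_rate
        self.num_samples = num_samples

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio, fs = torchaudio.load(self.files[idx])
        if fs != self.sample_rate:
            audio = torchaudio.transforms.Resample(fs, self.sample_rate)(audio)
        audio = audio.mean(dim=0, keepdim=True)  # downmix to mono
        # Pad or randomly crop to a fixed length so batches can be stacked
        if audio.shape[-1] < self.num_samples:
            audio = torch.nn.functional.pad(audio, (0, self.num_samples - audio.shape[-1]))
        else:
            offset = torch.randint(0, audio.shape[-1] - self.num_samples + 1, (1,)).item()
            audio = audio[:, offset:offset + self.num_samples]
        return audio
```

With this in place, something like `dataset = WavFolderDataset("path/to/wavs", sample_rate=SAMPLE_RATE)` could replace the LibriSpeechDataset block in the script above (the path is a placeholder); the rest of the training loop should not need to change.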
@jameshball Did you successfully train? If so, can you share what you did and learned?