Fine-tuning the AST model for music emotion classification overfits

Open moonquekes opened this issue 1 year ago • 3 comments

Hi, @YuanGongND,

Great work! When I asked New Bing what the SOTA model for music classification is, it told me AST!

I used your model for music emotion classification, but it always overfits: the training set loss goes down normally, but the test set loss stays almost the same.

When I cheated by putting the validation set into the training set, I was able to reach 99% accuracy by the seventh epoch.

The dataset I use is mtg-jamendo-dataset, which includes 56 mood and theme labels. I have recalculated the mean and std of the training set. Whether I use the full 56 labels or merge them into 8 categories, and whether or not I use the AudioSet pretraining, the overfitting described above occurs.
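For readers reproducing the "recalculated mean and std" step, here is a minimal sketch of one way to accumulate those statistics over a training set without holding every spectrogram in memory. The class name `RunningStats` and the helper `normalize` are hypothetical, and the inputs are assumed to be log-mel filterbank matrices given as nested lists (frames by bins). The `2 * std` scaling in `normalize` is an assumption about how the AST dataloader applies `dataset_mean` and `dataset_std`, so check it against the code you actually run.

```python
import math


class RunningStats:
    """Accumulate a global mean/std over many spectrograms, one at a time."""

    def __init__(self):
        self.count = 0
        self.total = 0.0
        self.total_sq = 0.0

    def update(self, fbank):
        # fbank: iterable of frames, each frame an iterable of float bin values
        for frame in fbank:
            for value in frame:
                self.count += 1
                self.total += value
                self.total_sq += value * value

    @property
    def mean(self):
        return self.total / self.count

    @property
    def std(self):
        # Population standard deviation; clamp tiny negative rounding errors.
        variance = self.total_sq / self.count - self.mean ** 2
        return math.sqrt(max(variance, 0.0))


def normalize(fbank, mean, std):
    # Assumption: input is scaled by 2 * std, as the AST dataloader is believed to do.
    return [[(value - mean) / (2 * std) for value in frame] for frame in fbank]


# Toy usage with two tiny "spectrograms" standing in for training clips.
stats = RunningStats()
stats.update([[1.0, 2.0], [3.0, 4.0]])
stats.update([[5.0, 6.0]])
print(stats.mean, stats.std)
```

The statistics should come from the training split only; computing them over training plus validation or test data leaks information about the evaluation set into the input normalization.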

This is my run.sh:

#!/bin/bash
#SBATCH -p gpu
#SBATCH -x sls-titan-[0-2]
#SBATCH --gres=gpu:4
#SBATCH -c 4
#SBATCH -n 1
#SBATCH --mem=48000
#SBATCH --job-name="ast-esc50"
#SBATCH --output=./log_%j.txt
# set -x
# comment this line if not running on sls cluster
# . /data/sls/scratch/share-201907/slstoolchainrc
# source ../../venvast/bin/activate
export TORCH_HOME=../../pretrained_models
export CUDA_VISIBLE_DEVICES=1,2,3
workers=4
prenum=4
class=8
epoch=100
batch_size=32
model=ast
dataset=8-mtg-jamendo
imagenetpretrain=True
audiosetpretrain=False
bal=none
if [ "$audiosetpretrain" == True ]
then
  lr=1e-5
else
  lr=1e-4
fi
freqm=24
timem=96
mixup=0

fstride=10
tstride=10
# -2.0373452 4.056306
dataset_mean=-2.0373452
dataset_std=4.056306
audio_length=1024
noise=False

metrics=acc
loss=CE
warmup=False
lrscheduler_start=5
lrscheduler_step=1
lrscheduler_decay=0.85

base_exp_dir=./exp/test-${dataset}-f$fstride-t$tstride-imp$imagenetpretrain-asp$audiosetpretrain-b$batch_size-lr${lr}

# Refuse to overwrite an existing experiment directory.
if [ -d "$base_exp_dir" ]; then
  echo 'exp exist'
  exit
fi
mkdir -p "$base_exp_dir"


exp_dir=${base_exp_dir}

# tr_data=/root/data/XMY/ast-master/src/data/train-merge.json
tr_data=/root/data/XMY/ast-master/src/data/tr_va-merge.json
te_data=/root/data/XMY/ast-master/src/data/test-merge.json


python -W ignore ./run.py --model ${model} --dataset ${dataset} \
--data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
--label-csv /root/data/XMY/ast-master/src/data/mtg_class_merge_labels_indices.csv --n_class ${class} \
--lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model False \
--freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
--tstride $tstride --fstride $fstride --imagenet_pretrain $imagenetpretrain --audioset_pretrain $audiosetpretrain \
--metrics ${metrics} --loss ${loss} --warmup ${warmup} --lrscheduler_start ${lrscheduler_start} --lrscheduler_step ${lrscheduler_step} --lrscheduler_decay ${lrscheduler_decay} \
--dataset_mean ${dataset_mean} --dataset_std ${dataset_std} --audio_length ${audio_length} --noise ${noise} --num-workers ${workers} \
--pre_num ${prenum}

# python ./get_esc_result.py --exp_path ${base_exp_dir}
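To make the scheduler settings in this script concrete, the following is a small sketch of the learning rate implied by `lrscheduler_start=5`, `lrscheduler_step=1`, and `lrscheduler_decay=0.85` over 100 epochs. It assumes the training loop applies a multi-step decay (multiply by the decay factor at every milestone from the start epoch onward) and that epochs are counted from 0. Both are assumptions about the training code, not something stated in the post, and the function name `lr_at_epoch` is hypothetical.

```python
def lr_at_epoch(base_lr, epoch, start=5, step=1, decay=0.85, horizon=1000):
    """Learning rate after multi-step decay with milestones start, start+step, ..."""
    milestones = range(start, horizon, step)
    n_decays = sum(1 for m in milestones if m <= epoch)
    return base_lr * decay ** n_decays


# Values taken from the run.sh above (lr=1e-4 when audiosetpretrain=False).
for epoch in (0, 4, 5, 10, 20, 50, 99):
    print(epoch, lr_at_epoch(1e-4, epoch))
```

Under these assumptions the rate shrinks by 15% every epoch from epoch 5 on, so by the later epochs of a 100-epoch run it is many orders of magnitude below the starting value and the weights barely move.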

--Mingyu Xiong

moonquekes avatar Feb 03 '24 13:02 moonquekes

Hi, I am facing a similar overfitting issue. If you managed to solve it, can you tell me how?

I am fine-tuning the model for real vs. fake audio classification. I got better accuracy by reducing the learning rate, but the validation loss is still too high. These are the metrics I am getting:

acc: 0.936561 AUC: 0.980722 Avg Precision: 0.500000 Avg Recall: 1.000000 d_prime: 2.925854 train_loss: 0.018238 valid_loss: 0.534479

aarshilpatel avatar Feb 29 '24 00:02 aarshilpatel

Sorry, I haven't solved this overfitting problem yet. For now, my guess is that AST, being a large Transformer-based model, needs a large dataset, and that the authors' pre-trained model is not necessarily useful for our task. Please let me know what you did if you eventually solve it. I am currently experimenting with other datasets and simpler models.
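For a rough sense of scale on the input side, the sketch below counts the patches (Transformer tokens) produced for one clip with the settings from the run.sh above (`fstride=10`, `tstride=10`, `audio_length=1024`). The 16x16 patch size and the 128 mel bins are assumptions based on the usual AST configuration and do not appear in the script; the function name `ast_patch_grid` is hypothetical.

```python
def ast_patch_grid(n_mels=128, n_frames=1024, patch=16, fstride=10, tstride=10):
    """Return (patches along frequency, patches along time, total patches)."""
    f_dim = (n_mels - patch) // fstride + 1
    t_dim = (n_frames - patch) // tstride + 1
    return f_dim, t_dim, f_dim * t_dim


print(ast_patch_grid())  # (12, 101, 1212)
```

Each clip therefore becomes a sequence of over a thousand tokens under these assumptions, which is consistent with the guess that the model has a lot of capacity to memorize a small training set.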

moonquekes avatar Feb 29 '24 04:02 moonquekes

I have 40+ audio data for deepfake detection classification. If anyone wants to collaborate with me, just reach out; I need help fine-tuning this.

Rizwanali324 avatar Jul 13 '24 08:07 Rizwanali324