vibration_gan
Implements GANs that generate time-series signals for imbalanced learning problems. The experiments use the CWRU bearing dataset.
GAN for generating time-series vibration signals, used to improve the classification accuracy of a fault diagnosis model trained on imbalanced data.
Personal undergraduate thesis project.
dataset
CWRU bearing data download
environment setup
- python 3.x
- tensorflow 1.15
- keras
- sklearn
- matplotlib
- numpy
data generation
- train a GAN with a limited number of target-class signals:
$ python train_gan.py --phase='train' --GAN_type='WGAN-GP' --target='B007' --imbalance_ratio=50
- generate target-class signals with a pretrained GAN:
$ python train_gan.py --phase='generate' --checkpoint_dir=which-pretrained-model-in-checkpoint-dir --target='B007' --imbalance_ratio=50
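The following is a minimal NumPy sketch of what an imbalanced training set might look like. It assumes that `--imbalance_ratio=50` keeps only 1/50 of the target class's samples while leaving every other class intact, so the GAN must generate the missing samples. This reading of the flag is an assumption. The helper `make_imbalanced` and the synthetic data are hypothetical and are not taken from `train_gan.py`.

```python
import numpy as np

def make_imbalanced(X, y, target, imbalance_ratio, seed=0):
    """Hypothetical helper (assumed semantics of --imbalance_ratio).

    Keeps all samples of the non-target classes but only
    len(target samples) // imbalance_ratio samples of the target class.
    Returns the reduced set and how many target samples would need to
    be generated (e.g. by the GAN) to restore balance.
    """
    rng = np.random.default_rng(seed)
    target_idx = np.flatnonzero(y == target)
    other_idx = np.flatnonzero(y != target)
    n_keep = max(1, len(target_idx) // imbalance_ratio)
    kept = rng.choice(target_idx, size=n_keep, replace=False)
    idx = np.sort(np.concatenate([other_idx, kept]))
    n_to_generate = len(target_idx) - n_keep
    return X[idx], y[idx], n_to_generate

# Synthetic stand-in for segmented vibration signals:
# 4 classes x 100 segments of 1024 points each (random noise, not CWRU data).
X = np.random.default_rng(1).normal(size=(400, 1024))
y = np.repeat(np.array(["normal", "B007", "IR007", "OR007"]), 100)

X_imb, y_imb, n_gen = make_imbalanced(X, y, target="B007", imbalance_ratio=50)
print((y_imb == "B007").sum(), n_gen)  # 2 98
```

Under this assumption, a ratio of 50 leaves the GAN with only 2 real `B007` segments to learn from. It would then need to generate 98 segments to match the other classes.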
data evaluation
- use mmd.py to measure the discrepancy (maximum mean discrepancy, MMD) between real and generated data
- use tsne.py to visualize real and generated data with t-SNE
- use fault_diagnosis.py to train the diagnosis model on a training set balanced by an oversampling method ('GAN', 'SMOTE', 'ADASYN' or 'RANDOM'):
$ python fault_diagnosis.py --imbalance_ratio=50 --oversampling_method='GAN' --generated_data_dir='\generated_data\ORDER_minmax_ratio50'
- compare GAN with other oversampling methods, e.g. ADASYN:
$ python fault_diagnosis.py --imbalance_ratio=50 --oversampling_method='ADASYN'
- train the diagnosis model on a balanced real dataset only, without oversampling:
$ python fault_diagnosis.py --imbalance_ratio=1 --oversampling_method='none'
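The README does not say which kernel or bandwidth `mmd.py` uses. The following standalone NumPy sketch therefore assumes a Gaussian (RBF) kernel with a hand-picked `sigma` and computes the biased estimate of squared MMD between two sample sets. It is an illustration of the metric, not the repository's implementation. A lower value means the generated samples are closer in distribution to the real ones.

```python
import numpy as np

def rbf_kernel(A, B, sigma):
    """Gaussian kernel matrix between rows of A (n, d) and rows of B (m, d)."""
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b,
    # clipped at 0 to remove tiny negative values from rounding.
    d2 = (np.sum(A**2, axis=1)[:, None]
          + np.sum(B**2, axis=1)[None, :]
          - 2.0 * A @ B.T)
    return np.exp(-np.maximum(d2, 0.0) / (2.0 * sigma**2))

def mmd2_biased(X, Y, sigma=1.0):
    """Biased estimate of squared MMD: mean k(x,x') + mean k(y,y') - 2 mean k(x,y)."""
    k_xx = rbf_kernel(X, X, sigma).mean()
    k_yy = rbf_kernel(Y, Y, sigma).mean()
    k_xy = rbf_kernel(X, Y, sigma).mean()
    return k_xx + k_yy - 2.0 * k_xy

# Synthetic example: "real" vs. a second draw from the same distribution
# vs. a shifted distribution (stand-ins for good and poor generated signals).
rng = np.random.default_rng(0)
real = rng.normal(size=(200, 16))
same = rng.normal(size=(200, 16))
shifted = rng.normal(loc=1.0, size=(200, 16))

mmd_same = mmd2_biased(real, same, sigma=4.0)
mmd_shifted = mmd2_biased(real, shifted, sigma=4.0)
print(mmd_same < mmd_shifted)  # True
```

The biased estimator is always non-negative and equals zero when both sample sets are identical. Its value depends strongly on `sigma`, so MMD numbers are only comparable when computed with the same kernel settings.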
reference
DCGAN_WGAN_WGAN-GP_LSGAN_SNGAN_RSGAN_BEGAN_ACGAN_PGGAN_TensorFlow