diffusion-classifier

About the diffusion model implementation

Open · hadwinn opened this issue 2 years ago · 8 comments

I am currently doing classification on my own dataset. To improve it, my idea is to fine-tune a diffusion model on my dataset so that accuracy improves on that specific dataset. Is there any problem with this idea? Please advise. If it works, how do I load my own diffusion model? Thanks!

hadwinn · Mar 22 '24
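For context on the idea above: Diffusion Classifier scores each class by how well the conditional denoiser predicts the noise added to an image when conditioned on that class's prompt, and predicts the class with the lowest average error. Fine-tuning can only help if it lowers the error for the correct prompt relative to the other prompts. Below is a simplified pure-Python sketch of that decision rule, not the repo's implementation: `eps_model` and the toy prototype model are hypothetical stand-ins for the UNet, timesteps are sampled at random, and refinements of the real code (such as adaptively discarding unlikely classes) are omitted.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Hypothetical interface: eps_model(x_t, t, class_idx) -> predicted noise for that class's prompt
EpsModel = Callable[[Vector, int, int], Vector]


def add_noise(x0: Vector, eps: Vector, alpha_bar: float) -> Vector:
    """Forward diffusion: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    a, s = alpha_bar ** 0.5, (1.0 - alpha_bar) ** 0.5
    return [a * xi + s * ei for xi, ei in zip(x0, eps)]


def diffusion_classify(x0: Vector, eps_model: EpsModel, num_classes: int,
                       alpha_bars: Sequence[float], n_samples: int = 32,
                       seed: int = 0) -> Tuple[int, List[float]]:
    """Return (predicted class, mean squared noise-prediction error per class)."""
    rng = random.Random(seed)
    errors = [0.0] * num_classes
    for _ in range(n_samples):
        t = rng.randrange(len(alpha_bars))
        eps = [rng.gauss(0.0, 1.0) for _ in x0]
        x_t = add_noise(x0, eps, alpha_bars[t])
        # Every class is scored on the same (t, eps) pair, so the errors are directly comparable
        for c in range(num_classes):
            pred = eps_model(x_t, t, c)
            errors[c] += sum((p - e) ** 2 for p, e in zip(pred, eps)) / len(eps)
    errors = [e / n_samples for e in errors]
    return min(range(num_classes), key=lambda c: errors[c]), errors


def make_toy_eps_model(prototypes: Sequence[Vector], alpha_bars: Sequence[float]) -> EpsModel:
    """Hypothetical stand-in for the UNet: assumes the clean image equals the class prototype."""
    def eps_model(x_t: Vector, t: int, c: int) -> Vector:
        a, s = alpha_bars[t] ** 0.5, (1.0 - alpha_bars[t]) ** 0.5
        return [(xi - a * mi) / s for xi, mi in zip(x_t, prototypes[c])]
    return eps_model


alpha_bars = [0.9, 0.5, 0.1]
protos = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
pred, errs = diffusion_classify([0.0, 1.0], make_toy_eps_model(protos, alpha_bars), 3, alpha_bars)
print(pred)  # 1
```

In this toy run the true class gets (numerically) zero error because the toy denoiser is exact for it; with a real UNet every prompt has a nonzero error, and the margins between prompts decide the prediction.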

I have implemented loading my own dataset and tried some optimization methods, such as https://github.com/chaofengc/TexForce, but classification accuracy has not improved. So far, the only thing that improves accuracy is switching to a different Stable Diffusion version. Any suggestions?

hadwinn · Apr 15 '24

How are you fine-tuning your diffusion model?

alexlioralexli · Apr 17 '24

I followed this guide: https://webcache.googleusercontent.com/search?q=cache:https://ngwaifoong92.medium.com/how-to-fine-tune-stable-diffusion-using-lora-85690292c6a8 and used this dataset: https://www.heywhale.com/mw/dataset/5e732227c59d610036227d89

Training command:

accelerate launch train_text_to_image_lora.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base \
  --train_data_dir="weather" \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=100 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --output_dir="output" \
  --validation_prompt="a picture of a rainy day"

This produced the weather.safetensors file, which I load with the following code:

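import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

# MODEL_IDS is defined elsewhere in the diffusion-classifier code (version string -> Hugging Face model ID)


def get_sd_model(args, lora_path=None):
    # lora_path is a hypothetical extra argument: the LoRA file produced by training,
    # e.g. "output/weather.safetensors"
    if args.dtype == 'float32':
        dtype = torch.float32
    elif args.dtype == 'float16':
        dtype = torch.float16
    else:
        raise NotImplementedError

    assert args.version in MODEL_IDS.keys()
    model_id = MODEL_IDS[args.version]
    # Local copy of the base model; it must be the same SD version the LoRA was trained on.
    # To download from the Hub instead, use model_path = model_id
    model_path = r"C:\Users\Hadwin\.cache\huggingface\hub\SD-2-0"

    scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=dtype)

    # Apply the fine-tuned LoRA weights with the pipeline's LoRA loader.
    # This modifies pipe.unet in place, so the unet returned below includes the LoRA layers.
    if lora_path is not None:
        pipe.load_lora_weights(lora_path)
    pipe.enable_xformers_memory_efficient_attention()

    vae = pipe.vae
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    unet = pipe.unet

    return vae, tokenizer, text_encoder, unet, scheduler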

But it didn't work.

hadwinn · Apr 20 '24
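The LoRA script in the training command above reads `--train_data_dir` with the Hugging Face `datasets` `imagefolder` loader and, by default, takes captions from a `text` column. One way to make the fine-tuning class-aware is a `metadata.jsonl` whose captions use the same wording as the classification prompts. This is a sketch under assumptions: one subfolder per class (e.g. `weather/rainy/...`), a hypothetical caption template modeled on the validation prompt, and a loader that resolves `file_name` relative to the folder holding `metadata.jsonl`.

```python
import json
from pathlib import Path
from typing import Dict, List

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def write_caption_metadata(data_dir, template: str = "a picture of a {} day") -> List[Dict[str, str]]:
    """Write data_dir/metadata.jsonl with one caption per image, built from its class folder name."""
    root = Path(data_dir)
    rows = []
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for img in sorted(class_dir.iterdir()):
            if img.is_file() and img.suffix.lower() in IMAGE_EXTS:
                rows.append({
                    # Path relative to the folder that holds metadata.jsonl
                    "file_name": img.relative_to(root).as_posix(),
                    # Hypothetical template; keep it identical to the prompts used for classification
                    "text": template.format(class_dir.name),
                })
    with open(root / "metadata.jsonl", "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")
    return rows
```

Calling `write_caption_metadata("weather")` before training would give every rainy image the caption "a picture of a rainy day", and so on for the other class folders.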

How was the DiT model in the paper trained? If I want to train it on my own dataset, should I follow the "Training DiT" section of https://github.com/facebookresearch/DiT?

hadwinn · Apr 23 '24

What accuracy do you get with the default SD 2.0 versus your LoRA fine-tuned version? Are you sure you're changing the prompts correctly?

For the DiT results in the paper, we downloaded the pre-trained checkpoints from the DiT repo. You should be able to reproduce this with their training code.

alexlioralexli · Apr 24 '24
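On the prompt question: each class prompt and its class index have to line up with how the labels were assigned. torchvision's `ImageFolder` assigns class indices by sorting the class subfolder names alphabetically, so a prompts file built with the same ordering avoids mismatched indices. A minimal sketch; the CSV column names `prompt`, `classname`, `classidx` and the prompt template are assumptions to check against the prompt CSVs that ship with the repo.

```python
import csv
from pathlib import Path
from typing import Dict


def class_index_map(data_dir) -> Dict[str, int]:
    """Map each class subfolder name to an index, sorting names alphabetically
    (the same rule torchvision's ImageFolder uses)."""
    classes = sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def write_prompts_csv(class_to_idx: Dict[str, int], out_path,
                      template: str = "a picture of a {} day") -> None:
    """Write one prompt per class; column names are assumed, not taken from the repo."""
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["prompt", "classname", "classidx"])
        writer.writeheader()
        for name, idx in sorted(class_to_idx.items(), key=lambda kv: kv[1]):
            # Hypothetical template; it should match the captions used for LoRA fine-tuning
            writer.writerow({"prompt": template.format(name), "classname": name, "classidx": idx})
```

With class folders `sunny`, `cloudy`, `rainy`, the indices come out as cloudy=0, rainy=1, sunny=2 regardless of the order the folders were created in.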

My dataset is a 6-class weather dataset. Mean per-class accuracy is about 43% with the default SD 2.0, about 67% with SD 1.5, and about 71% with SD 2.1. Adding my own fine-tuning did not improve accuracy significantly, only by about 1%.

I also tried DiT: tested with the official pre-trained weights, it scored 0.23. I then prepared the weather dataset and trained DiT for roughly 700 steps (2 epochs). Testing with these weights still gave 0.23. Is this a coincidence?

hadwinn · Apr 27 '24
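For reference when comparing the numbers above, mean per-class accuracy averages the accuracy (recall) of each class separately, so a small class counts as much as a large one. A minimal sketch, assuming labels and predictions are plain lists of class indices:

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def mean_per_class_accuracy(labels: Sequence[Hashable],
                            preds: Sequence[Hashable]) -> Tuple[float, Dict[Hashable, float]]:
    """Return (mean of per-class accuracies, accuracy of each class present in labels)."""
    if len(labels) != len(preds):
        raise ValueError("labels and preds must have the same length")
    total: Dict[Hashable, int] = defaultdict(int)
    correct: Dict[Hashable, int] = defaultdict(int)
    for y, p in zip(labels, preds):
        total[y] += 1
        correct[y] += int(y == p)
    per_class = {c: correct[c] / total[c] for c in total}
    return sum(per_class.values()) / len(per_class), per_class


labels = [0, 0, 0, 1]
preds = [0, 0, 1, 1]
mean_acc, per_class = mean_per_class_accuracy(labels, preds)
print(per_class)  # {0: 0.6666666666666666, 1: 1.0}
```

Here overall accuracy is 3/4, while mean per-class accuracy is (2/3 + 1)/2 ≈ 0.83. A baseline that always predicts the same class scores exactly 1/K on this metric (about 0.167 with 6 classes), which is a useful reference point when two checkpoints give the same score.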

Hey @hadwinn, did it work in the end? I'm planning a similar modification.

hritam-98 · May 06 '24

@hritam-98 It hasn't worked out so far. We can exchange ideas.

hadwinn · May 07 '24