
Network not learning

Open · ValterFallenius opened this issue · 7 comments

Training loss is not decreasing. I have implemented the network in the "lightning" PR branch with PyTorch Lightning and tried to find any bugs. The network compiles without issues and seems to generate gradients, but it still fails to learn anything. I have tried playing around with the learning rate and plotting the data at different stages, but even with 4 training samples (which it should easily be able to overfit) the loss fails to decrease after 100 epochs...
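
A generic version of this overfitting sanity check, sketched with a hypothetical stand-in model and random data rather than the repo's code: a correctly wired network should drive the loss toward zero on 4 fixed samples within a few hundred steps.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real network: 15 input channels in,
# 6 output classes out, matching the hyperparameters listed below.
model = nn.Sequential(nn.Conv2d(15, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 6, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.randn(4, 15, 112, 112)           # 4 fixed samples
y = torch.randint(0, 6, (4, 112, 112))     # per-pixel class targets

for step in range(301):
    loss = nn.CrossEntropyLoss()(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 100 == 0:
        print(step, loss.item())           # should fall steadily toward 0
```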

Here is the training loss plotted: [W&B loss chart, 2022-03-05 11:20:29]

It seems like it's doing something, but not nearly quickly enough to overfit such a small dataset. Something is wrong...

Hyperparameters: n_samples = 4 (100 in the larger run), hidden_dim = 8, forecast_steps = 1, input_channels = 15, output_channels = 6 (512 commented out), input_size = 112, num_workers = 8, batch_size = 1, learning_rate = 1e-2
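
The same settings collected as a plain dict for readability; the names mirror the issue text and are not necessarily the model's actual constructor arguments.

```python
# Hyperparameters from the run above; names follow the issue text,
# not necessarily the model's real constructor signature.
hparams = dict(
    n_samples=4,          # 100 in the larger run
    hidden_dim=8,
    forecast_steps=1,
    input_channels=15,
    output_channels=6,    # 512 commented out in the source
    input_size=112,
    num_workers=8,
    batch_size=1,
    learning_rate=1e-2,
)
```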

Below is a Weights & Biases gradient report. As you can see, most gradients are non-zero; I'm not sure why image_encoder has very small gradients for its biases...

wandb summary at epoch 83 (trainer/global_step 251): train/loss_epoch 1.72826, train/loss_step 1.73303, validation/loss_epoch 1.76064. Per-parameter gradient 2-norms:

| Parameter | grad 2-norm (epoch) | grad 2-norm (step) |
| --- | --- | --- |
| head.bias | 0.0746 | 0.049 |
| head.weight | 0.0862 | 0.081 |
| image_encoder.module.module.0.bias | 0.0 | 0.0 |
| image_encoder.module.module.0.weight | 0.06653 | 0.043 |
| image_encoder.module.module.2.bias | 0.00017 | 0.0001 |
| image_encoder.module.module.2.weight | 0.003 | 0.0019 |
| image_encoder.module.module.3.bias | 0.0 | 0.0 |
| image_encoder.module.module.3.weight | 0.16387 | 0.1125 |
| image_encoder.module.module.4.bias | 0.00013 | 0.0001 |
| image_encoder.module.module.4.weight | 0.00203 | 0.0012 |
| image_encoder.module.module.5.bias | 0.0 | 0.0 |
| image_encoder.module.module.5.weight | 0.15237 | 0.1151 |
| image_encoder.module.module.6.bias | 0.0032 | 0.0018 |
| image_encoder.module.module.6.weight | 0.00157 | 0.0012 |
| image_encoder.module.module.7.bias | 0.00497 | 0.003 |
| image_encoder.module.module.7.weight | 0.11753 | 0.0915 |
| temporal_agg.0.axial_attentions.0.fn.to_kv.weight | 0.03763 | 0.0277 |
| temporal_agg.0.axial_attentions.0.fn.to_out.bias | 0.0412 | 0.0289 |
| temporal_agg.0.axial_attentions.0.fn.to_out.weight | 0.05167 | 0.0369 |
| temporal_agg.0.axial_attentions.0.fn.to_q.weight | 0.0008 | 0.0008 |
| temporal_agg.0.axial_attentions.1.fn.to_kv.weight | 0.04393 | 0.0216 |
| temporal_agg.0.axial_attentions.1.fn.to_out.bias | 0.0412 | 0.0289 |
| temporal_agg.0.axial_attentions.1.fn.to_out.weight | 0.04287 | 0.027 |
| temporal_agg.0.axial_attentions.1.fn.to_q.weight | 0.0014 | 0.0009 |
| temporal_enc.rnn.cell_list.0.conv_h1.bias | 0.00197 | 0.001 |
| temporal_enc.rnn.cell_list.0.conv_h1.weight | 0.03313 | 0.0216 |
| temporal_enc.rnn.cell_list.0.conv_h2.bias | 0.00103 | 0.0004 |
| temporal_enc.rnn.cell_list.0.conv_h2.weight | 0.00353 | 0.002 |
| temporal_enc.rnn.cell_list.0.conv_zr.bias | 0.00133 | 0.0009 |
| temporal_enc.rnn.cell_list.0.conv_zr.weight | 0.02123 | 0.0147 |
| total | 0.31513 | 0.2254 |
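
For what it's worth, a report in this shape is what PyTorch Lightning 1.x emits when gradient-norm tracking is enabled; the setup below is an assumption about how this run was configured, not something stated in the thread.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Assumed configuration: with track_grad_norm=2, Lightning 1.x logs
# per-parameter "grad_2.0_norm/..." metrics like the ones above.
trainer = Trainer(logger=WandbLogger(project="metnet"), track_grad_norm=2)
```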

I have plotted the inputs as they flow through the layers, and none of them seems to do anything unexpected: [plots of the input layer, after image_encoder, after temp_encoder, after agg, after head, after softmax, and output vs ground truth]

I'm out of ideas and would appreciate any input.

To Reproduce

Steps to reproduce the behavior:

  1. Clone https://github.com/ValterFallenius/metnet
  2. Download the data samples from the link in the README
  3. Install requirements
  4. Run the code

ValterFallenius · Mar 05 '22

Great work plotting all this information!

I'm afraid I'm not familiar enough with the code or the model to really help. But in the meantime, to help us debug, would it be at all possible to share a link to the Weights & Biases page for this training run, please?

JackKelly · Mar 08 '22

Here is a link to this particular run: w&b-run

The author Casper mentioned the loss seemed too big. Right now I'm using `torch.nn.CrossEntropyLoss(y_hat - y)`. Is this the same loss they use in the paper?
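
For reference, a minimal sketch of the usual `nn.CrossEntropyLoss` pattern, with assumed shapes rather than the repo's actual code: the criterion is constructed once and then called with raw logits and integer class targets, not given `y_hat - y` in the constructor. Note too that `CrossEntropyLoss` applies log-softmax internally, so feeding it outputs that have already passed through a softmax layer would flatten the gradients.

```python
import torch
import torch.nn as nn

# Standard usage sketch; shapes are assumptions, not the MetNet code.
criterion = nn.CrossEntropyLoss()

y_hat = torch.randn(1, 6, 112, 112)        # raw logits: (batch, classes, H, W)
y = torch.randint(0, 6, (1, 112, 112))     # class indices: (batch, H, W)

loss = criterion(y_hat, y)                 # log-softmax applied internally
print(loss.item())
```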

See his comment below:

My question:

My network doesn't seem to be able to train properly, but I am having a hard time finding the bug. I am trying to run it like you suggested on a small subset of the data, with only a single lead time and far fewer hidden layers, but it doesn't train well. The PyTorch model compiles and doesn't report any bugs, but the network still won't reduce the error even when run repeatedly on the same training sample.

His answer:

If the plot is showing log loss, the number seems too high: it means your probability mass on the correct class is on average exp(-1.77) ≈ 0.17, and just predicting no-rain all the time should give you something much better. One thing to note, however, is that log-loss changes after the first few hundred updates are usually super small. Most likely you have a sign error in the loss, or the params are not being updated, since the grads are not zero? I can't really help beyond that.
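
A quick way to test Casper's second suspect (a hypothetical sketch with a stand-in model, not the real network): snapshot the parameters, run one optimization step, and check that every tensor actually changed.

```python
import torch
import torch.nn as nn

# Stand-in model and data; replace with the real module and batch.
model = nn.Linear(10, 5)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x, y = torch.randn(4, 10), torch.randint(0, 5, (4,))

before = {n: p.detach().clone() for n, p in model.named_parameters()}
loss = nn.CrossEntropyLoss()(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()

# With non-zero grads, every parameter tensor should differ from its snapshot.
stale = [n for n, p in model.named_parameters() if torch.equal(before[n], p.detach())]
print("parameters that did not update:", stale)   # expect an empty list
```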

ValterFallenius · Mar 08 '22

Sorry to chip in. @ValterFallenius, where did you get the sample data? From openclimatefix/goes and openclimatefix/mrms on Hugging Face, or somewhere else? Would you mind sharing a bit? I'd like to give it a try as well. Thanks in advance.

bmkor · Apr 24 '22

Hey @bmkor,

I am actually using neither; you can find my raw data in #27. However, I have not published the elevation and longitude/latitude data I used, so let me know if you need it. But unless you are writing a thesis for the Swedish government, I think you might be better off using the original dataset available on Hugging Face ^^

Also, I am not using any GOES data, since its quality over Sweden is poor because of the lack of geostationary satellite coverage.

/Valter

ValterFallenius · Apr 24 '22

Thanks a lot for your prompt reply and comment. I'll try to use what's available on Hugging Face first and see if I can make the model run.

bmkor · Apr 24 '22

> Thanks a lot for your prompt reply and comment. I'll try to use what's available on Hugging Face first and see if I can make the model run.

Hi, just so you know, the goes dataset currently doesn't have any data in it; I'm working through adding data for that. The MRMS dataset does, although I am still finishing up the dataset script. But if you want to get started with the radar data, you can just download the Zarr files themselves and open them locally quite easily, as sketched below.
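
A minimal sketch of opening a downloaded store locally with xarray; the filename is a placeholder, not an actual file from the dataset.

```python
import xarray as xr

# "mrms_sample.zarr" is a placeholder path for a downloaded Zarr store.
ds = xr.open_zarr("mrms_sample.zarr")
print(ds)  # inspect variables, dimensions and coordinates
```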

jacobbieker · Apr 24 '22

Hello, can you tell me where I can download the MRMS dataset? Thank you very much!

CUITCHENSIYU · Jul 14 '24