interpretable_predictions
Accelerate training convergence
Depends on #4
I found #4 to be mostly sufficient to get aspect 0 training stably (even with batch size 1024), but not so much for aspects 1 and 2.
One common failure mode seems to be the following: the model picks everything (i.e. high selection) 🡒 lambda_0 increases exponentially 🡒 lagrange_0 suddenly dominates the MSE 🡒 the model picks nothing (i.e. extremely low selection) 🡒 it then takes a very long time to escape from this particular bad local minimum (especially with the learning rate already significantly decayed by then).
So this PR is trying to address this issue by
- stopping lr-decay once "selection" drops too low (this proves very effective in the experiments below)
- additionally clamping the lambdas from above, based on a maximum allowed "lagrange/mse" ratio
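For concreteness, a minimal sketch of the two mechanisms. All names here (`maybe_decay_lr`, `clamp_lambda`, the default thresholds) are hypothetical illustrations, not the actual identifiers in this PR:

```python
def maybe_decay_lr(scheduler, selection, min_selection=0.05):
    """Skip the lr-decay step while selection is collapsed, so the model
    keeps enough learning rate to escape the 'picks nothing' minimum.
    (Hypothetical helper; threshold is an assumed example value.)"""
    if selection >= min_selection:
        scheduler.step()


def clamp_lambda(lambda0, constraint, mse, max_ratio=5.0):
    """Cap lambda0 so that the Lagrange penalty (lambda0 * constraint)
    never exceeds max_ratio times the MSE, preventing it from suddenly
    dominating the loss. (Hypothetical helper; max_ratio is an assumed
    example value.)"""
    if constraint > 0.0 and mse > 0.0:
        lambda0 = min(lambda0, max_ratio * mse / constraint)
    return lambda0
```

In a real training loop, `scheduler` would be e.g. a `torch.optim.lr_scheduler` instance and `clamp_lambda` would run right after the exponential lambda update, before the loss is assembled.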
I'm also wondering whether a smarter initialization, e.g. letting the HardKuma layer output the expected selection rate right from the beginning, would help stabilize training as well? Did you try that?
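As a rough illustration of that initialization idea, here is a sketch that bias-initializes the gate logits so the expected selection at step 0 matches a target rate. For simplicity it uses a sigmoid gate as a stand-in; the actual HardKuma parameterization has no such simple closed form, so `init_gate_bias` and `target_selection` are purely hypothetical:

```python
import math


def init_gate_bias(target_selection=0.3):
    """Return the bias for a sigmoid gate so that, with zero-initialized
    weights, the expected selection rate at initialization equals
    target_selection. (Sigmoid stands in for HardKuma here; a HardKuma
    version would need to invert its stretched-CDF instead.)"""
    assert 0.0 < target_selection < 1.0
    return math.log(target_selection / (1.0 - target_selection))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))
```

The point is just that starting near the target selection rate avoids the initial "pick everything" phase that kicks off the lambda blow-up described above.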