roc-star
PyTorch implementation as a class
Hi, thanks a lot for your work; the idea described here is smart and simple!
To apply it in my work I had to rewrite your code in class form. This also made the code a bit cleaner by using PyTorch broadcasting. Hope it will be useful for someone else too =)
I also slightly changed the logic: here batches are saved into a FIFO queue, and during each forward call the last sample_size elements are taken instead of a random subset.
update: Added a small fix from brandenkmurray in the update_gamma function.
```python
import torch
from torch.nn.modules.loss import _Loss


class RocStarLoss(_Loss):
    """Smooth approximation for ROC AUC
    """
    def __init__(self, delta = 1.0, sample_size = 1000, sample_size_gamma = 10000, update_gamma_each=500):
        r"""
        Args:
            delta: Param from article
            sample_size (int): Number of examples to take for ROC AUC approximation
            sample_size_gamma (int): Number of examples to take for Gamma parameter approximation
            update_gamma_each (int): Number of steps after which to recompute gamma value.
        """
        super().__init__()
        self.delta = delta
        self.sample_size = sample_size
        self.sample_size_gamma = sample_size_gamma
        self.update_gamma_each = update_gamma_each
        self.steps = 0
        size = max(sample_size, sample_size_gamma)

        # Randomly init labels
        self.y_pred_history = torch.rand((size, 1))
        self.y_true_history = torch.randint(2, (size, 1))

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Tensor of model predictions in [0, 1] range. Shape (B x 1)
            y_true: Tensor of true labels in {0, 1}. Shape (B x 1)
        """
        if self.steps % self.update_gamma_each == 0:
            self.update_gamma()
        self.steps += 1

        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]

        # Take last `sample_size` elements from history
        y_pred_history = self.y_pred_history[- self.sample_size:]
        y_true_history = self.y_true_history[- self.sample_size:]

        positive_history = y_pred_history[y_true_history > 0]
        negative_history = y_pred_history[y_true_history < 1]

        if positive.size(0) > 0:
            diff = negative_history.view(1, -1) + self.gamma - positive.view(-1, 1)
            loss_positive = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_positive = 0

        if negative.size(0) > 0:
            diff = negative.view(1, -1) + self.gamma - positive_history.view(-1, 1)
            loss_negative = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_negative = 0

        loss = loss_negative + loss_positive

        # Update FIFO queue
        batch_size = y_pred.size(0)
        self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred))
        self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true))
        return loss

    def update_gamma(self):
        # Take last `sample_size_gamma` elements from history
        y_pred = self.y_pred_history[- self.sample_size_gamma:]
        y_true = self.y_true_history[- self.sample_size_gamma:]

        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]

        # Pairwise differences: matrix of size (#positives x #negatives)
        diff = positive.view(-1, 1) - negative.view(1, -1)
        AUC = (diff > 0).type(torch.float).mean()
        num_wrong_ordered = (1 - AUC) * diff.flatten().size(0)

        # Adjust gamma, so that among correctly ordered samples `delta * num_wrong_ordered` are considered
        # ordered incorrectly with gamma added
        correct_ordered = diff[diff > 0].flatten().sort().values
        idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
        self.gamma = correct_ordered[idx]
```
This is great! It also corrals stateful things like gamma into the object where they belong.
Do you have a quick example or unit test of this thing?
This looks worthwhile to merge into the original codebase I'm thinking ...
No, no quick examples; I'm using it in a different domain, not NLP tasks. I'll add tests a bit later, but for now you can adjust the example in your README as follows:
```python
from torch.utils.data import DataLoader

train_ds = CatDogDataset(train_files, transform)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE)
roc_star_loss = RocStarLoss()

for epoch in range(epochs):
    for X, y in train_dl:
        preds = model(X)
        # ...
        loss = roc_star_loss(preds, y)  # forward() takes (y_pred, y_true)
        # ...
```
The output value of epoch_update_gamma is the same as that of self.update_gamma() (which is also ~25% faster).
The loss value can be slightly different for small sample_size, but is always roughly similar.
Works for me, thanks! Please let me know if this loss function works well in your problem domain!
Oh, important note: I have to update the text, since I've actually found that 2.0 is the best default for delta.
(Feel free to turn that knob though!)
It looks like the best delta value mostly depends on the top achievable performance. For example, if your regular AUC is bouncing around ~0.75, it doesn't make sense to use a big value, while if your AUC is near 0.95, even bigger values of delta can be used to further shift prediction probabilities.
Interesting; I'll have to keep an eye on that issue. I've wondered if gamma should never drop below some minimum value, or if delta*wrong_pairs should be clipped above some minimum ...
(I'll break this code out from the paper soon, as a proper source file, and then we can do a proper gitful checkin of your changes.)
I often get IndexErrors in this line when delta > 1, because sometimes the calculated index is greater than len(correct_ordered):
```python
correct_ordered[int(num_wrong_ordered * self.delta)]
```
I've kind of fixed this by changing it to the code below, but as I don't fully understand the code, I'm not sure whether that's a proper fix or whether there is some other root issue that needs to be fixed:
```python
correct_ordered = diff[diff > 0].flatten().sort().values
# Deltas > 1 can cause an index error if the resulting value is more than len(correct_ordered)
correct_ordered_idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
self.gamma = correct_ordered[correct_ordered_idx]
```
@brandenkmurray Yes, your fix is valid. I did something similar in the code I use. It often happens at the very beginning, when AUC is ~0.5-0.6. Reducing the delta value also helps, but I didn't investigate how this influences the result.
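The overflow can be pinned down from update_gamma() itself. With n_+ positives and n_- negatives in the gamma sample there are M = n_+ n_- pairs; AUC is computed as the fraction of pairs with diff > 0, so len(correct_ordered) equals AUC·M and num_wrong_ordered equals (1 − AUC)·M (ties count as wrong). The unclamped index is out of range exactly when the left-hand inequality below holds, which in turn requires the right-hand condition:

```math
\lfloor \delta \,(1-\mathrm{AUC})\, M \rfloor \;\ge\; \mathrm{AUC}\cdot M
\quad\Longrightarrow\quad
\delta\,(1-\mathrm{AUC}) \ge \mathrm{AUC}
\iff
\mathrm{AUC} \le \frac{\delta}{1+\delta}
```

So the clamp can only kick in when AUC ≤ δ/(1+δ): 0.5 for δ = 1 and 2/3 for δ = 2. That fits the observation that it shows up early in training (AUC ~0.5-0.6) and that lowering delta makes it rarer.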
Does anyone have a toy kernel somewhere I can fork to take a look at this bug?
I have one in a private repo 😞 Will share it in mid-August, once the competition I'm taking part in ends. Nothing really serious here, just an indexing error I didn't notice at the beginning.
Cool - good luck in competition! If the loss function can help you rank high, that would be awesome!
Very nice work! I'm running into "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time." when using this, though. I'll stick to the normal integration in the repo for now, which runs fine... or did you see any performance improvements using this implementation instead of the normal one?
Hey, I tried using this class on a toy model + data, and when I call loss.backward() I get this error:
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
When I call loss.backward(retain_graph=True) instead, I get another error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of TBackward, is at version 10; expected version 9 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I have attached my Colab notebook so you can recreate the error. Any help would be appreciated.
Hi,
I actually never used this class for training and ended up using another approach, so the code definitely needs some debugging.
Try using the hint from PyTorch: torch.autograd.set_detect_anomaly(True).
My guess is that the error happens somewhere in the FIFO queue, but I don't have time to debug it now.
Please comment if you fix the issue!
Thanks for the Colab notebook, @PotatoSpudowski . Taking a look right now ...
Hi guys, here's the fix for @PotatoSpudowski's error: "gradient computation has been modified by an inplace operation".
At the end of RocStarLoss.forward(), make this change:
```diff
- self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred))
- self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true))
+ self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred.clone().detach()))
+ self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true.clone().detach()))
```
This is in addition to calling the loss.backward method like this: loss.backward(retain_graph=True)
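A toy sketch of why the detach matters (the model, sizes, and the simplified pairwise loss are illustrative assumptions, not RocStarLoss itself): if the queue keeps y_pred still attached to its graph, the next step's loss backpropagates through the previous step's graph, whose saved tensors were already freed. In this toy setup, once the queue only stores detached copies, a plain loss.backward() runs without retain_graph=True.

```python
import torch

torch.manual_seed(0)
# Illustrative toy model; not taken from the thread.
model = torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Sigmoid())
opt = torch.optim.SGD(model.parameters(), lr=0.1)


def run(steps, detach):
    history = torch.rand(16, 1)  # FIFO queue of past predictions
    for _ in range(steps):
        y_pred = model(torch.randn(8, 4))
        # Simplified pairwise term against the queue (stands in for the RocStar hinge).
        diff = history.view(1, -1) - y_pred.view(-1, 1)
        loss = torch.relu(diff).pow(2).mean()
        opt.zero_grad()
        loss.backward()  # plain backward, no retain_graph
        opt.step()
        # Store a copy cut off from this step's graph (or the raw tensor, for comparison).
        stored = y_pred.detach().clone() if detach else y_pred
        history = torch.cat((history[y_pred.size(0):], stored))
    return loss.item()


run(5, detach=True)  # completes
try:
    run(5, detach=False)  # step 2 backprops into step 1's already-freed graph
except RuntimeError as err:
    print("RuntimeError:", str(err).splitlines()[0])
```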
I'm going to (finally!) pull this into the codebase tonight; for now, let me just give a complete version of the working class:
```python
import torch
from torch.nn.modules.loss import _Loss


class RocStarLoss(_Loss):
    """Smooth approximation for ROC AUC
    """
    def __init__(self, delta = 1.0, sample_size = 10, sample_size_gamma = 10, update_gamma_each=10):
        r"""
        Args:
            delta: Param from article
            sample_size (int): Number of examples to take for ROC AUC approximation
            sample_size_gamma (int): Number of examples to take for Gamma parameter approximation
            update_gamma_each (int): Number of steps after which to recompute gamma value.
        """
        super().__init__()
        self.delta = delta
        self.sample_size = sample_size
        self.sample_size_gamma = sample_size_gamma
        self.update_gamma_each = update_gamma_each
        self.steps = 0
        size = max(sample_size, sample_size_gamma)

        # Randomly init labels
        self.y_pred_history = torch.rand((size, 1))
        self.y_true_history = torch.randint(2, (size, 1))

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Tensor of model predictions in [0, 1] range. Shape (B x 1)
            y_true: Tensor of true labels in {0, 1}. Shape (B x 1)
        """
        #y_pred = _y_pred.clone().detach()
        #y_true = _y_true.clone().detach()
        if self.steps % self.update_gamma_each == 0:
            self.update_gamma()
        self.steps += 1

        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]

        # Take last `sample_size` elements from history
        y_pred_history = self.y_pred_history[- self.sample_size:]
        y_true_history = self.y_true_history[- self.sample_size:]

        positive_history = y_pred_history[y_true_history > 0]
        negative_history = y_pred_history[y_true_history < 1]

        if positive.size(0) > 0:
            diff = negative_history.view(1, -1) + self.gamma - positive.view(-1, 1)
            loss_positive = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_positive = 0

        if negative.size(0) > 0:
            diff = negative.view(1, -1) + self.gamma - positive_history.view(-1, 1)
            loss_negative = torch.nn.functional.relu(diff ** 2).mean()
        else:
            loss_negative = 0

        loss = loss_negative + loss_positive

        # Update FIFO queue with copies detached from the current graph
        batch_size = y_pred.size(0)
        self.y_pred_history = torch.cat((self.y_pred_history[batch_size:], y_pred.clone().detach()))
        self.y_true_history = torch.cat((self.y_true_history[batch_size:], y_true.clone().detach()))
        return loss

    def update_gamma(self):
        # Take last `sample_size_gamma` elements from history
        y_pred = self.y_pred_history[- self.sample_size_gamma:]
        y_true = self.y_true_history[- self.sample_size_gamma:]

        positive = y_pred[y_true > 0]
        negative = y_pred[y_true < 1]

        # Pairwise differences: matrix of size (#positives x #negatives)
        diff = positive.view(-1, 1) - negative.view(1, -1)
        AUC = (diff > 0).type(torch.float).mean()
        num_wrong_ordered = (1 - AUC) * diff.flatten().size(0)

        # Adjust gamma, so that among correctly ordered samples `delta * num_wrong_ordered` are considered
        # ordered incorrectly with gamma added
        correct_ordered = diff[diff > 0].flatten().sort().values
        idx = min(int(num_wrong_ordered * self.delta), len(correct_ordered)-1)
        self.gamma = correct_ordered[idx]
```
This works, thank you!
One more thing I noticed is that the loss is sometimes nan:
Epoch 046: | Loss: 0.02856 | Acc: 65.750
Epoch 047: | Loss: nan | Acc: 63.500
Epoch 048: | Loss: 0.02594 | Acc: 52.000
Epoch 049: | Loss: 0.02916 | Acc: 65.750
And when we then try to update gamma, we get an error:
ValueError: cannot convert float NaN to integer
I believe this can be fixed by taking a larger sample, but it would be nice to have a check condition before calling update_gamma().
I will now implement this in my competition notebook and let you know if we make progress.
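Both symptoms can be traced in the code: if the sampled history contains only one class, the positive × negative pair matrix is empty, `mean()` of an empty tensor is NaN (the same empty-matrix mean gives the NaN loss in forward()), and `int(NaN)` then raises the ValueError. A minimal guard, sketched here as a standalone hypothetical `safe_update_gamma` function rather than as the class method, keeps the previous gamma whenever either class, or every correctly ordered pair, is missing:

```python
import torch


def safe_update_gamma(y_pred, y_true, delta, old_gamma):
    """Hypothetical sketch: recompute gamma only if the sample allows it."""
    positive = y_pred[y_true > 0]
    negative = y_pred[y_true < 1]
    if positive.numel() == 0 or negative.numel() == 0:
        # Empty pair matrix: mean() would be NaN and int(NaN) raises ValueError.
        return old_gamma
    diff = positive.view(-1, 1) - negative.view(1, -1)
    auc = (diff > 0).float().mean()
    num_wrong_ordered = (1 - auc) * diff.numel()
    correct_ordered = diff[diff > 0].sort().values
    if correct_ordered.numel() == 0:
        # AUC == 0: no correctly ordered pair to pick gamma from.
        return old_gamma
    # Same clamp as brandenkmurray's fix.
    idx = min(int(num_wrong_ordered * delta), len(correct_ordered) - 1)
    return correct_ordered[idx]
```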
I see the NaN now - you are right, this occurs when a batch all belongs to the same class. I'm also noticing a second issue with this code's implementation; it is missing a step where it should do a random subsample.
Give me a few minutes to patch ...
Taking a bit more. I'll get this patched up within the next 6-12 hrs for sure. The basic issue is this code has drifted a bit from the original code here https://github.com/iridiumblue/roc-star/blob/master/example.py . This code introduced a few errors.
There are 3 easy fixes and 1 less easy one I'm making:
1 - Stomp out NaN at the end of forward() using:
```python
# Kick NaN's to the curb.
loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
```
2 - Increase the sample sizes (all 3 of them) in the constructor to ~1000. (This isn't batch size or anything; it's an internal value that needs to be about that size for the purposes of a certain tensor operation. That subsample keeps the tensor within a reasonable memory size for the GPU.)
3 - Add a method to call update_gamma at the end of each epoch.
4 - Change a couple of lines of code from slicing to random subsampling.
Stay tuned ... If you don't need a class, you can use the old non-class code (see example.py) that works correctly. Or just wait a bit more :)
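Fix 4 could look roughly like the hypothetical `sample_history` helper below (the name and signature are illustrative, not taken from example.py): it draws a random subset of rows from the queue, keeping each prediction paired with its label, instead of always slicing the last `sample_size` rows.

```python
import torch


def sample_history(y_pred_history, y_true_history, sample_size, generator=None):
    """Hypothetical sketch: random rows from the FIFO queue instead of the last `sample_size`."""
    n = y_pred_history.size(0)
    k = min(sample_size, n)
    # One shared permutation keeps predictions and labels aligned.
    idx = torch.randperm(n, generator=generator)[:k]
    return y_pred_history[idx], y_true_history[idx]
```

Since the queue holds max(sample_size, sample_size_gamma) rows, slicing only ever sees the most recent predictions whenever the queue is longer than sample_size, while this version also mixes in older ones.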
I will wait for your implementation of the class no worries.
Currently, I had 0.2 to 0.3 improvement in AUC and significant improvement in ACC. And I followed a bit of your implementation where I train for 2 epochs on BCE Loss and then using the above loss I train for a few more epochs.
Thank you for your work @iridiumblue and @zakajd for ROC-STAR and this class implementation! I would love to try this loss function, but for my workflow it would be easier if it could be used as a class. Looking forward to any progress!
I got dragged off into other things, let me try to get this wrapped for you today!
More to come ...
Hey, is there any semi-fixed code for the class?
> Currently, I had 0.2 to 0.3 improvement in AUC and significant improvement in ACC. And I followed a bit of your implementation where I train for 2 epochs on BCE Loss and then using the above loss I train for a few more epochs.

I got worse performance... how did you get an improvement?
> 4 - Change a couple lines of code from slicing to random subsampling.
Do we really need random subsampling instead of slicing if the dataset is shuffled? Aren't they doing basically the same thing?
Don't know if this is still getting any updates, but I just wanted to point out that the class implementation above (https://github.com/iridiumblue/roc-star/issues/2#issuecomment-709645076) is at risk of data leakage if used during validation/testing because every prediction+label that passes through it is saved temporarily and used in subsequent loss calculations.
That means the samples at the end of the validation set will influence the loss at the beginning of the next training epoch.
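One possible mitigation, sketched under the assumption that the loss is switched with the standard nn.Module `.train()`/`.eval()` calls (RocStarLoss derives from `_Loss`, which is an nn.Module): only push into the queue while in training mode. `FifoHistory` is a hypothetical helper, not part of the class above; using a separate loss instance for validation would be another option.

```python
import torch


class FifoHistory(torch.nn.Module):
    """Hypothetical helper: FIFO queue of past predictions/labels that is
    only updated while the module is in training mode."""

    def __init__(self, size):
        super().__init__()
        self.y_pred = torch.rand(size, 1)
        self.y_true = torch.randint(2, (size, 1))

    def push(self, y_pred, y_true):
        if not self.training:
            # After .eval(): read-only, so validation/test samples never
            # reach later training-step losses.
            return
        b = y_pred.size(0)
        self.y_pred = torch.cat((self.y_pred[b:], y_pred.detach()))
        self.y_true = torch.cat((self.y_true[b:], y_true.detach().to(self.y_true.dtype)))
```

Usage: call `history.eval()` before the validation loop and `history.train()` before resuming training.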
I want to thank the author and others for this great idea! Does anyone have a bug-free class implementation now? Thanks!
Shouldn't it be relu(diff) ** 2 instead of relu(diff ** 2)?
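A quick check of the question: since `diff ** 2` is never negative, `relu(diff ** 2)` is identical to `diff ** 2`, so the relu does nothing and pairs that are already correctly ordered by more than gamma (negative `diff`) are still penalized. Clipping first and then squaring keeps the hinge. The `diff` values below are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Illustrative margins (negative + gamma - positive):
# -1.0 = pair already ordered correctly by more than gamma, 0.5 = violation.
diff = torch.tensor([-1.0, 0.5])

as_written = F.relu(diff ** 2)  # squaring first: relu is a no-op
hinge_sq = F.relu(diff) ** 2    # clip first: satisfied pairs cost nothing

print(as_written.tolist())  # [1.0, 0.25]
print(hinge_sq.tolist())    # [0.0, 0.25]
```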