Neg partial log likelihood loss is 0 every time the batch size is 1
Hi there,
Thanks for sharing this wonderful library!
I'm running a survival analysis with the Cox proportional hazards model and, due to GPU memory constraints, I have to use a batch size of 1. Every time I run the model, the loss from `cox.neg_partial_log_likelihood` is 0.
I looked into the implementation of `_partial_likelihood_cox`, and it seems that when the batch size is 1, `log_denominator` takes the same value as `log_hz_sorted`, so the loss evaluates to 0.
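For reference, here is a minimal sketch in plain PyTorch (no torchsurv dependency, assuming the same cumulative-logsumexp formulation of the partial likelihood) that shows why the loss degenerates with a single sample: the risk set contains only that one subject, so the log-denominator equals the log hazard itself.

```python
import torch

# Predicted log hazard for a batch of size 1, where the subject had an event.
log_hz = torch.tensor([0.7])
event = torch.tensor([True])

# Log-denominator of the Cox partial likelihood: cumulative logsumexp over
# the risk set. With one subject, the risk set is just that subject, so
# log_denominator == log_hz exactly.
log_denominator = torch.logcumsumexp(log_hz, dim=0)

# Negative partial log-likelihood averaged over events.
loss = -(log_hz[event] - log_denominator[event]).mean()
print(loss.item())  # ≈ 0.0, up to floating-point rounding
```

With any batch size > 1, the logsumexp over the risk set strictly exceeds the individual log hazard, so the loss becomes informative again.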
Is there a workaround for this issue? Here is a link to the corresponding code: https://github.com/Novartis/torchsurv/blob/799eb30f4b123f4871659ffcfbb7c914d1f67fcd/src/torchsurv/loss/cox.py#L174
Thank you in advance!