Entropy sampling issue
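There is an error in your implementation of entropy sampling: whenever a predicted class probability is exactly 0, the resulting score is nan.
The following example reproduces the problem with the same computation the implemented entropy sampling uses: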
import torch
probs = [[0.1, 0.2, 0.7], [0.0, 0.1, 0.9], [0.0, 0.5, 0.5]]
probs = torch.tensor(probs)
log_probs = torch.log(probs)
U = (probs*log_probs).sum(1)
print(log_probs)
print(U)
---------
-- Print Result --
tensor([[-2.3026, -1.6094, -0.3567],
        [   -inf, -2.3026, -0.1054],
        [   -inf, -0.6931, -0.6931]])
tensor([-0.8018,     nan,     nan])
------------
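The nan entries come from terms of the form 0 * log(0): torch.log(0) returns -inf, and 0 * -inf is nan under IEEE 754 floating-point rules. Entropy is conventionally defined with 0 log 0 = 0, which matches the limit below, so a class with zero probability should simply contribute nothing to U (note that U as computed above is the negative entropy):

```latex
U(p) = \sum_i p_i \log p_i = -H(p), \qquad
\lim_{p \to 0^{+}} p \log p = 0 \;\Rightarrow\; 0 \log 0 := 0
```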
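If this is not addressed, every sample that assigns zero probability to any class receives a nan score, which corrupts the uncertainty ranking. For example, the third sample, [0.0, 0.5, 0.5], is split evenly between two classes and should rank as highly uncertain, yet its score is nan. As a result, the implemented entropy sampling will likely perform much worse than margin sampling.
This can be fixed simply by replacing the infinite entries of log_probs with 0, so that each 0 * log(0) term contributes 0 instead of nan: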
log_probs = torch.log(probs)
log_probs[log_probs == float("-inf")] = 0
log_probs[log_probs == float("inf")] = 0
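A minimal sketch of the corrected scoring, plus an equivalent alternative using torch.xlogy (which returns 0 wherever its first argument is 0; assumes PyTorch 1.8 or later). The function names entropy_scores_masked, entropy_scores_xlogy and query_entropy are hypothetical and not taken from the repository; query_entropy assumes the query step picks the n samples with the smallest U, i.e. the highest entropy.

```python
import torch


def entropy_scores_masked(probs: torch.Tensor) -> torch.Tensor:
    """Row-wise sum of p * log p, with infinite log terms zeroed (the fix above)."""
    log_probs = torch.log(probs)
    # log(0) = -inf; zeroing it makes each 0 * log(0) term contribute 0, not nan
    log_probs[torch.isinf(log_probs)] = 0
    return (probs * log_probs).sum(1)


def entropy_scores_xlogy(probs: torch.Tensor) -> torch.Tensor:
    """Same quantity via torch.xlogy, which yields 0 wherever probs == 0."""
    return torch.xlogy(probs, probs).sum(1)


def query_entropy(probs: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical query step: indices of the n most uncertain samples.

    U = sum(p log p) is the negative entropy, so the most uncertain
    samples have the smallest U and come first in an ascending argsort.
    """
    U = entropy_scores_xlogy(probs)
    return U.argsort()[:n]


probs = torch.tensor([[0.1, 0.2, 0.7], [0.0, 0.1, 0.9], [0.0, 0.5, 0.5]])
print(entropy_scores_masked(probs))  # tensor([-0.8018, -0.3251, -0.6931])
print(entropy_scores_xlogy(probs))   # same values, no nan
print(query_entropy(probs, 1))       # tensor([0])
```

With either variant all three scores are finite, and the evenly split third sample now ranks as more uncertain than the confident second one. The xlogy form avoids the in-place masking step entirely.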