
Entropy sampling issue

Open ElvinKim opened this issue 2 years ago • 0 comments

There is an error in your implementation of entropy sampling.

The following example reproduces the result of the implemented entropy sampling:

import torch

probs = [[0.1, 0.2, 0.7], [0.0, 0.1, 0.9], [0.0, 0.5, 0.5]]
probs = torch.tensor(probs)

# torch.log(0.0) evaluates to -inf, and 0 * -inf produces nan
log_probs = torch.log(probs)

U = (probs * log_probs).sum(1)

print(log_probs)
print(U)
-- Print Result --
tensor([[-2.3026, -1.6094, -0.3567],
        [   -inf, -2.3026, -0.1054],
        [   -inf, -0.6931, -0.6931]])
tensor([-0.8018,     nan,     nan])

If this issue is not addressed, the results of the implemented entropy sampling will inevitably be much worse than those of margin sampling, because any sample containing a zero probability gets a nan uncertainty score.

This can be solved simply by adding the following code after the log:

log_probs = torch.log(probs)

# zero out infinities so that 0 * log(0) contributes 0 instead of nan
log_probs[log_probs == float("-inf")] = 0
log_probs[log_probs == float("inf")] = 0
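As an alternative, the same effect can be achieved by clamping the probabilities away from zero before taking the log, so -inf never appears in the first place. The sketch below is not from the repository; the helper name `entropy` and the epsilon value are my own assumptions:

```python
import torch

def entropy(probs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: clamp probabilities to a small epsilon (assumed
    # value 1e-12) so torch.log never returns -inf.
    log_probs = torch.log(probs.clamp_min(1e-12))
    # Same uncertainty score as in the issue: row-wise sum of p * log p.
    # For p == 0 the clamped term is 0 * log(1e-12) == 0, not nan.
    return (probs * log_probs).sum(1)

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.0, 0.1, 0.9],
                      [0.0, 0.5, 0.5]])
print(entropy(probs))  # all entries finite, no nan
```

With this variant the zero-probability rows get the same finite scores as with the in-place infinity replacement, since both make the p = 0 terms contribute exactly 0.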

ElvinKim · Jul 26 '22 09:07