ttt-plus-plus
covariance function in cifar/discrepancy.py seems strange
I'm confused about how the covariance is calculated in the code: the result of the covariance function differs from what np.cov gives. I'm sorry if I have misunderstood something.
Also, in ttt++.py, lines 286-289 do the following:
- get the queue
- update the queue's features with feat_ext
- concatenate feat_ext with the queue's features.
However, this means feat_ext from the second step is counted twice when the covariance is computed in line 292. A toy example is shown below.
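The duplication can be sketched in NumPy as follows. This is a hypothetical reconstruction, not the code in ttt++.py: it assumes a fixed-size FIFO queue whose update step pushes feat_ext in and drops the oldest rows, and the names `queue`, `updated_queue` and `combined` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
queue = rng.normal(size=(6, 3))     # hypothetical stored features, (queue_len, dim)
feat_ext = rng.normal(size=(4, 3))  # hypothetical current-batch features, (batch, dim)

# Step 2 (assumed FIFO): push the batch into the queue, keep the newest queue_len rows
updated_queue = np.concatenate([queue, feat_ext], axis=0)[-queue.shape[0]:]

# Step 3: concatenate the batch with the already-updated queue
combined = np.concatenate([feat_ext, updated_queue], axis=0)

# Every row of feat_ext now appears twice in `combined`
dup_rows = sum(any(np.array_equal(r, q) for q in updated_queue) for r in feat_ext)
print(dup_rows)  # 4

# Duplicated rows get double weight, so the covariance differs from the deduplicated one
cov_dup = np.cov(combined, rowvar=False)
cov_unique = np.cov(np.unique(combined, axis=0), rowvar=False)
print(np.allclose(cov_dup, cov_unique))  # False
```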
Thanks for your interest.
My torch implementation of covariance is based on Eq. 2 of the Deep CORAL paper.
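For reference, Eq. 2 of Deep CORAL (written here without the source-domain subscript) computes the covariance of a feature matrix $D \in \mathbb{R}^{n \times d}$, with one sample per row, where $\mathbf{1}$ is a column vector of ones:

$$
C = \frac{1}{n-1}\left(D^\top D - \frac{1}{n}\left(\mathbf{1}^\top D\right)^\top\left(\mathbf{1}^\top D\right)\right)
$$

Writing $\mu = \frac{1}{n}(\mathbf{1}^\top D)^\top$ for the feature mean, the bracket equals $\sum_i x_i x_i^\top - n\,\mu\mu^\top = \sum_i (x_i - \mu)(x_i - \mu)^\top$, so $C$ is the unbiased sample covariance, the same $n-1$ normalization np.cov uses by default.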
From a quick test, it produces the same results as the official NumPy API (see the figure below).
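A comparable check can be sketched in NumPy. This is a transcription of Eq. 2, not the repository's torch code (which is not reproduced here); `coral_cov` and the random `feats` batch are hypothetical. Note that np.cov treats each row as a variable by default, so comparing against an (n_samples, n_features) matrix requires `rowvar=False`; omitting it is one easy way to get a mismatch.

```python
import numpy as np

def coral_cov(x: np.ndarray) -> np.ndarray:
    """Covariance per Deep CORAL Eq. 2; rows of `x` are samples, columns are features."""
    n = x.shape[0]
    ones = np.ones((1, n))
    col_sum = ones @ x                       # 1^T D, shape (1, n_features)
    return (x.T @ x - col_sum.T @ col_sum / n) / (n - 1)

rng = np.random.default_rng(42)
feats = rng.normal(size=(32, 5))             # hypothetical (n_samples, n_features) batch

ours = coral_cov(feats)
ref = np.cov(feats, rowvar=False)            # features are columns here
print(np.allclose(ours, ref))                # True

# Without rowvar=False, np.cov correlates the 32 samples instead of the 5 features
print(np.cov(feats).shape)                   # (32, 32)
```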
Could you please verify the test case on your end and let me know if you find anything incorrect?
Thanks, Yuejiang