
Bug when I try to evaluate a model on my own dataset

Open AlbertoPresta opened this issue 2 years ago • 2 comments

Hi,

I think I found a bug when I run the following command (suggested by you):

python3 -m compressai.utils.eval_model checkpoint /path/to/images/folder/ -a $ARCH -p $MODEL_CHECKPOINT...

The bug is the following:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 310, in <module>
    main(sys.argv[1:])
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 286, in main
    model = load_func(*opts, run)
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 150, in load_checkpoint
    return architectures[arch].from_state_dict(state_dict).eval()
  File "/opt/conda/lib/python3.7/site-packages/compressai/models/google.py", line 157, in from_state_dict
    N = state_dict["g_a.0.weight"].size(0)
KeyError: 'g_a.0.weight'

I also think the error comes from the fact that the actual state_dict of the net, i.e. state_dict["state_dict"], should be passed rather than state_dict itself; in my opinion, we should have something like:

N = state_dict["state_dict"]["g_a.0.weight"].size(0)
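For reference, this is roughly how I inspected my checkpoint to reach that conclusion (the path is just a placeholder for my own file):

import torch

ckpt = torch.load("/path/to/checkpoint.pth.tar", map_location="cpu")
print(list(ckpt.keys()))                      # the weights are nested under "state_dict"
print("g_a.0.weight" in ckpt)                 # False
print("g_a.0.weight" in ckpt["state_dict"])   # True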

Maybe I am missing something in my previous command.

Alberto

AlbertoPresta avatar Jun 09 '22 12:06 AlbertoPresta

This error usually occurs when the model hasn't been updated via compressai.utils.update_model after training. See here for an example. Currently, this removes one layer of the "state_dict" (i.e. ckpt <- ckpt["state_dict"]) and also updates the CDFs. (Note: if you want to do further training, please create a copy of the checkpoint as a backup before running update_model.)
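For intuition, here is a simplified sketch of roughly what that update step amounts to (not the actual update_model code; the architecture name, file paths, and the zoo import are assumptions from my setup and may differ in yours):

import shutil

import torch
from compressai.zoo.image import model_architectures as architectures

# Keep a backup, since the update rewrites the checkpoint contents.
shutil.copyfile("checkpoint.pth.tar", "checkpoint_backup.pth.tar")

ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
state_dict = ckpt["state_dict"]                    # drop the extra nesting level
net = architectures["bmshj2018-factorized"].from_state_dict(state_dict)
net.update(force=True)                             # recompute the quantized CDF tables
torch.save(net.state_dict(), "checkpoint-updated.pth.tar")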


An (unofficial) alternative that I personally use is just to call model.update(force=True) within load_checkpoint itself, and to remove the extra layer of "state_dict" when loading the checkpoint:

# compressai/utils/eval_model/__main__.py

def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    ckpt = torch.load(checkpoint_path)
    # Training checkpoints nest the weights under "state_dict"; unwrap if present.
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)  # for pre-trained models
    model = architectures[arch].from_state_dict(state_dict).eval()
    # Recompute the entropy coder CDF tables so compression/evaluation works.
    model.update(force=True)
    return model
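
With that change, a raw training checkpoint can be loaded directly, e.g. (architecture name and path are placeholders):

model = load_checkpoint("bmshj2018-factorized", "/path/to/checkpoint.pth.tar")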

YodaEmbedding avatar Jun 09 '22 19:06 YodaEmbedding

I download a pretrained model with the following code:

net = bmshj2018_factorized(quality=8, pretrained=True).eval().to(torch.device("cpu"))
net.update(force=True)

Then I try to print out the quantized CDF, and I get something like this (I show only the first dimension):

quantized_cdf = Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 65536, 0, ..., 0])
tail_mass = 1.

Moreover, all dimensions have tail_mass equal to 1... What does that mean? Should it be equal to (or close to) 1e-9? And what is the domain of the quantized CDF if I want to plot it?
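
In case it helps, this is roughly how I am reading those tensors; I am poking at internal buffers of the entropy bottleneck, so the attribute names below are just what I see in my installed version and may differ:

eb = net.entropy_bottleneck
cdf = eb._quantized_cdf      # shape (channels, max_cdf_length), values scaled to 2**16
lengths = eb._cdf_length     # valid length of each row
offsets = eb._offset         # left end of each channel's symbol range

# As far as I understand, for channel c the CDF covers the integer symbols
# offsets[c], offsets[c] + 1, ..., and the per-symbol probabilities would be:
c = 0
n = int(lengths[c])
pmf = (cdf[c, 1:n] - cdf[c, :n - 1]).float() / 2 ** 16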

Thanks in advance for the answers.

Alberto

AlbertoPresta avatar Jun 20 '22 12:06 AlbertoPresta