
Cannot use `load_from_checkpoint` in an offline environment

Open zzaebok opened this issue 7 months ago • 3 comments

Motivation

Due to my company's policy, I cannot access online servers (e.g., huggingface). So I tried to use the `load_from_checkpoint` function after manually downloading the wmt22-comet-da checkpoint. However, it raises an SSL error: "SSLError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded".

This error is raised because the `local_files_only` param is not passed to each model's `from_pretrained` method, so the library still tries to reach huggingface.co.

🚀 Feature

I think COMET needs to support a `local_files_only` parameter in the `load_from_checkpoint` function so that it does not raise an HTTP connection error.

Below is my current solution to the problem.

  • Add a `local_files_only` param to `load_from_checkpoint`.
  • Pass it to the `load_from_checkpoint` of the classes stored as values in `str2model` (https://github.com/Unbabel/COMET/blob/master/comet/models/__init__.py#L88).
    • This automatically forwards the param to model instantiation, because LightningModule's `load_from_checkpoint` passes extra kwargs through.
  • Add a `local_files_only` param to `CometModel.__init__()` (https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L94).
  • Pass it to the `from_pretrained` of the classes stored as values in `str2encoder` (https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L119).
  • Add a `local_files_only` param to each encoder class's `__init__`, for example:

```python
self.model = XLMRobertaModel(
    XLMRobertaConfig.from_pretrained(pretrained_model, local_files_only=local_files_only),
    add_pooling_layer=False,
)
```
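
To make the chain above concrete, here is a minimal, self-contained sketch of how the flag would travel from the top-level function down to the config call. `FakeConfig`, `SketchEncoder`, and `SketchModel` are hypothetical stand-ins I made up for illustration; they are not COMET's or transformers' real classes, and the checkpoint path and encoder name are placeholders.

```python
class FakeConfig:
    """Hypothetical stand-in for a Hugging Face config class."""

    def __init__(self, name: str, local_files_only: bool) -> None:
        self.name = name
        self.local_files_only = local_files_only

    @classmethod
    def from_pretrained(cls, name: str, local_files_only: bool = False) -> "FakeConfig":
        # The real from_pretrained would avoid network access when the flag is True;
        # this stand-in only records the value it received.
        return cls(name, local_files_only)


class SketchEncoder:
    """Hypothetical stand-in for an encoder class (a str2encoder value)."""

    def __init__(self, pretrained_model: str, local_files_only: bool = False) -> None:
        self.config = FakeConfig.from_pretrained(
            pretrained_model, local_files_only=local_files_only
        )

    @classmethod
    def from_pretrained(
        cls, pretrained_model: str, local_files_only: bool = False
    ) -> "SketchEncoder":
        return cls(pretrained_model, local_files_only=local_files_only)


class SketchModel:
    """Hypothetical stand-in for a CometModel subclass (a str2model value)."""

    def __init__(
        self, pretrained_model: str = "placeholder-encoder", local_files_only: bool = False
    ) -> None:
        self.encoder = SketchEncoder.from_pretrained(
            pretrained_model, local_files_only=local_files_only
        )

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str, **kwargs) -> "SketchModel":
        # Mimics only the kwargs forwarding into __init__; no checkpoint is read here.
        return cls(**kwargs)


def load_from_checkpoint(checkpoint_path: str, local_files_only: bool = False) -> SketchModel:
    """Top-level entry point with the proposed extra parameter."""
    return SketchModel.load_from_checkpoint(
        checkpoint_path, local_files_only=local_files_only
    )


model = load_from_checkpoint("path/to/model.ckpt", local_files_only=True)
print(model.encoder.config.local_files_only)
```

Keeping the default at `False` in every layer means existing callers would see no change in behavior.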

Can I make a PR? Please let me know if there is a better way.

zzaebok, Jul 26 '24 07:07