COMET
Cannot use `load_from_checkpoint` in an offline environment
Motivation
Due to my company's policy, I cannot access online servers (e.g., Hugging Face).
So I manually downloaded the `wmt22-comet-da` checkpoint and tried to load it with the `load_from_checkpoint` function.
However, this raises an SSL error: "SSLError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded".
The error occurs because no `local_files_only` param is passed to each model's `from_pretrained` method, so loading still tries to contact huggingface.co even though the checkpoint is local.
🚀 Feature
I think COMET should support a `local_files_only` parameter in the `load_from_checkpoint` function so that loading a local checkpoint does not raise an HTTP connection error.
Below is my current solution to the problem.
- Add a `local_files_only` param to `load_from_checkpoint`.
  - Pass it to the `load_from_checkpoint` method of the `str2model` value classes (https://github.com/Unbabel/COMET/blob/master/comet/models/init.py#L88).
    - It is then passed automatically to model instantiation through `LightningModule`'s `load_from_checkpoint` via `kwargs`.
- Add a `local_files_only` param to `CometModel.__init__()` (https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L94).
  - Pass it to the `from_pretrained` method of the `str2encoder` value classes (https://github.com/Unbabel/COMET/blob/master/comet/models/base.py#L119).
- Add a `local_files_only` param to each encoder class's `__init__`, for example:
```python
self.model = XLMRobertaModel(
    XLMRobertaConfig.from_pretrained(pretrained_model, local_files_only=local_files_only),
    add_pooling_layer=False,
)
```
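The way the flag is threaded through the proposed steps (top-level loader, model class, `__init__`, encoder, config) can be sketched in plain Python. This is a hypothetical sketch: `StandInConfig`, `StandInEncoder`, `StandInCometModel`, the `"regression_metric"` key and the hard-coded hyperparameters are illustrative stand-ins, not COMET's or transformers' real code, and the sketch runs without either library installed. The stand-in `load_from_checkpoint` classmethod only imitates how Lightning merges extra keyword arguments over the saved hyperparameters; it ignores the checkpoint path.

```python
# Hypothetical sketch: every class, key and hyperparameter below is a stand-in,
# not COMET's or transformers' real implementation.
from typing import Any, Dict, Type


class StandInConfig:
    """Stand-in for a Hugging Face config class such as XLMRobertaConfig."""

    def __init__(self, name: str, local_files_only: bool) -> None:
        self.name = name
        # Recorded so we can check that the flag reached the innermost call.
        self.local_files_only = local_files_only

    @classmethod
    def from_pretrained(cls, name: str, local_files_only: bool = False) -> "StandInConfig":
        # This stand-in only stores the values; it does no file or network I/O.
        return cls(name, local_files_only)


class StandInEncoder:
    """Encoder step: accept the flag and forward it to the config loader."""

    def __init__(self, pretrained_model: str, local_files_only: bool = False) -> None:
        self.config = StandInConfig.from_pretrained(
            pretrained_model, local_files_only=local_files_only
        )

    @classmethod
    def from_pretrained(cls, pretrained_model: str, local_files_only: bool = False) -> "StandInEncoder":
        return cls(pretrained_model, local_files_only=local_files_only)


str2encoder: Dict[str, Type[StandInEncoder]] = {"XLM-RoBERTa": StandInEncoder}


class StandInCometModel:
    """Model step: __init__ accepts the flag and passes it to the encoder."""

    def __init__(self, encoder_model: str, pretrained_model: str,
                 local_files_only: bool = False) -> None:
        self.encoder = str2encoder[encoder_model].from_pretrained(
            pretrained_model, local_files_only=local_files_only
        )

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str, **kwargs: Any) -> "StandInCometModel":
        # Imitates LightningModule.load_from_checkpoint: extra kwargs are merged
        # over the saved hyperparameters before __init__ is called. The path is
        # ignored here and the hyperparameters are hard-coded placeholders.
        hparams: Dict[str, Any] = {
            "encoder_model": "XLM-RoBERTa",
            "pretrained_model": "xlm-roberta-large",
        }
        hparams.update(kwargs)
        return cls(**hparams)


str2model: Dict[str, Type[StandInCometModel]] = {"regression_metric": StandInCometModel}


def load_from_checkpoint(checkpoint_path: str, local_files_only: bool = False) -> StandInCometModel:
    """Top-level step: accept the flag and forward it to the model class."""
    model_class = str2model["regression_metric"]
    return model_class.load_from_checkpoint(checkpoint_path, local_files_only=local_files_only)


model = load_from_checkpoint("path/to/checkpoint.ckpt", local_files_only=True)
print(model.encoder.config.local_files_only)  # True
```

Because every layer in this sketch gives the flag a default of `False`, existing callers would keep the current behaviour, and only callers who opt in get `local_files_only=True` forwarded down to the config call.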
Can I make a PR? Please let me know if there is a better way.