Styleformer
Update styleformer.py
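Add `torch_dtype=torch.float32` to each `from_pretrained` call, because otherwise inference on CPU fails with `RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'` (the CPU `addmm` kernel has no half-precision implementation in that setup):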
```python
# Force float32 weights: half-precision (float16) addmm is not implemented on CPU
self.ctf_model = AutoModelForSeq2SeqLM.from_pretrained(ctf_model_tag, use_auth_token=False, torch_dtype=torch.float32)
self.ftc_model = AutoModelForSeq2SeqLM.from_pretrained(ftc_model_tag, use_auth_token=False, torch_dtype=torch.float32)
self.atp_model = AutoModelForSeq2SeqLM.from_pretrained(atp_model_tag, use_auth_token=False, torch_dtype=torch.float32)
self.pta_model = AutoModelForSeq2SeqLM.from_pretrained(pta_model_tag, use_auth_token=False, torch_dtype=torch.float32)
```
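If a model has already been loaded in half precision, a similar fix can be applied after loading. The sketch below is a minimal illustration, not part of this PR: `to_cpu_float32` is a hypothetical helper (not a Styleformer or transformers function), it relies only on PyTorch's `nn.Module.to()` and `nn.Module.float()`, and a small `nn.Linear` stands in for the seq2seq models.

```python
import torch
from torch import nn


def to_cpu_float32(model: nn.Module) -> nn.Module:
    """Hypothetical helper: move a model to CPU and cast it to float32.

    nn.Module.float() casts all floating-point parameters and buffers
    to float32 in place and returns the same module.
    """
    return model.to("cpu").float()


# Stand-in for a model whose weights were loaded in float16.
half_model = nn.Linear(4, 2).half()
print(next(half_model.parameters()).dtype)  # torch.float16

model = to_cpu_float32(half_model)
x = torch.randn(1, 4)  # float32 input
y = model(x)           # the matrix multiply now runs in float32 on CPU
print(y.dtype)         # torch.float32
```

Passing `torch_dtype=torch.float32` at load time, as this PR does, loads the weights directly in float32; casting afterwards is mainly useful when the loading call is outside your control.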
I also increased `max_length` from 32 to 128.
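Assuming this `max_length` is the limit passed to `generate()` in Styleformer's transfer methods, it caps the length of the generated token sequence, so a rewrite that needs more tokens is cut off mid-sentence. The toy greedy loop below shows that effect; `greedy_decode` and `fake_next_token` are hypothetical stand-ins, not the transformers implementation.

```python
from typing import Callable, List


def greedy_decode(
    next_token: Callable[[List[int]], int],
    start_id: int,
    eos_id: int,
    max_length: int,
) -> List[int]:
    """Toy decoding loop: stop at EOS or once the sequence reaches max_length."""
    seq = [start_id]
    while len(seq) < max_length:
        tok = next_token(seq)
        seq.append(tok)
        if tok == eos_id:
            break
    return seq


def fake_next_token(seq: List[int]) -> int:
    """Hypothetical 'model': emits 60 content tokens (id 5), then EOS (id 1)."""
    return 1 if len(seq) > 60 else 5


short = greedy_decode(fake_next_token, start_id=0, eos_id=1, max_length=32)
long = greedy_decode(fake_next_token, start_id=0, eos_id=1, max_length=128)
print(len(short), short[-1])  # 32 5   -> truncated before EOS
print(len(long), long[-1])    # 62 1   -> finished with EOS
```

With a cap of 32 the output stops before the end-of-sequence token; with 128 it completes. A higher cap only lengthens generation when the model actually needs the extra tokens, since decoding still stops at EOS.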