Styleformer

Update styleformer.py

Open · SuperBruceJia opened this issue 1 year ago · 0 comments

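Add `torch_dtype=torch.float32` to each `from_pretrained` call to avoid `RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'`, which PyTorch raises when a model with float16 (Half) weights runs its matrix multiplications on the CPU. Requesting float32 explicitly loads all four models in full precision: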

```python
import torch
from transformers import AutoModelForSeq2SeqLM

# Load every model in float32 so CPU inference avoids the Half-precision addmm error
self.ctf_model = AutoModelForSeq2SeqLM.from_pretrained(ctf_model_tag, use_auth_token=False, torch_dtype=torch.float32)
self.ftc_model = AutoModelForSeq2SeqLM.from_pretrained(ftc_model_tag, use_auth_token=False, torch_dtype=torch.float32)
self.atp_model = AutoModelForSeq2SeqLM.from_pretrained(atp_model_tag, use_auth_token=False, torch_dtype=torch.float32)
self.pta_model = AutoModelForSeq2SeqLM.from_pretrained(pta_model_tag, use_auth_token=False, torch_dtype=torch.float32)
```
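If the same code should run on both GPU and CPU, one option is to pick the dtype per device and cast back to float32 whenever a model ends up on the CPU in half precision. This is a minimal sketch, not part of Styleformer: `pick_dtype` and `ensure_cpu_float32` are hypothetical helper names, and it assumes float16 matrix multiplications are unsupported on CPU for the PyTorch version in use. It relies only on `torch.nn.Module.float()`, which casts floating-point parameters and buffers to float32.

```python
import torch
import torch.nn as nn


def pick_dtype(device: str) -> torch.dtype:
    """Hypothetical helper: half precision on CUDA, float32 everywhere else.

    Assumption: float16 matmuls may be unsupported on CPU (the
    'addmm_impl_cpu_' error), so CPU devices always get float32.
    """
    return torch.float16 if device.startswith("cuda") else torch.float32


def ensure_cpu_float32(model: nn.Module) -> nn.Module:
    """Hypothetical helper: cast floating-point params/buffers to float32 in place."""
    return model.float()  # Module.float() mutates the module and returns it


if __name__ == "__main__":
    # Simulate a model whose weights arrived in half precision.
    layer = nn.Linear(4, 2).half()
    layer = ensure_cpu_float32(layer)
    out = layer(torch.ones(1, 4))  # float32 input, float32 weights
    print(out.dtype)
    print(pick_dtype("cpu"), pick_dtype("cuda:0"))
```

The value of `pick_dtype(device)` could then be passed as `torch_dtype=` to `AutoModelForSeq2SeqLM.from_pretrained`, keeping the float16 memory savings on CUDA while staying in float32 on CPU.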

I also increased `max_length` from 32 to 128 so that longer sentences are not cut off.

SuperBruceJia · Dec 12 '23 03:12