transformers
Custom model building missing key component not allowing easy access to .generate methods
System Info
transformers 4.39.3, Linux, Python 3.12
Who can help?
@gante
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
- Create a custom model according to the guidelines set here
- Overwrite the inherited method `prepare_inputs_for_generation`
- Call `model.generate` in order to access the different generation methods
This fails with: `AttributeError: 'dict' object has no attribute 'logits'`
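The root cause is visible without transformers at all: `generate` reads the forward output via attribute access (`outputs.logits`), which a dataclass supports but a plain dict does not. A minimal torch-free sketch of the failing access pattern (`MyOutput` and `decode_step` are made-up names):

```python
from dataclasses import dataclass

@dataclass
class MyOutput:
    logits: list

def decode_step(outputs):
    # mimics what generate does internally: attribute access on the output
    return outputs.logits

print(decode_step(MyOutput(logits=[0.1, 0.9])))  # [0.1, 0.9]

try:
    decode_step({"logits": [0.1, 0.9]})  # a plain dict has no .logits attribute
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'logits'
```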
In my case the model's forward has the signature `model.forward(src=None, src_mask=None, tgt=None, tgt_mask=None, **kwargs)`.
So, in the `prepare_inputs_for_generation()` method I'm setting `decoder_input_ids=model.config.decoder_start_token_id` and returning all the other `*args` and `**kwargs`.
Can anyone please provide a MWE of a custom architecture subclassing `PreTrainedModel`, with custom keyword arguments in `.forward()`, that allows the extended model easy access to any of the generation methods by calling `model.generate`?
I think this should be clear and evident in the docs so that we don't have to reinvent the wheel every time. Lots of people are looking for how to accomplish this.
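For reference, the override described above (mapping `generate`'s `input_ids` onto the custom `tgt` argument and seeding it with `decoder_start_token_id`) might be sketched like this — `CustomSeq2Seq` and its internals are hypothetical, modeled only on the signature quoted above:

```python
import torch

class CustomSeq2Seq:
    """Hypothetical stand-in for the custom model described above."""

    class config:
        decoder_start_token_id = 2

    def prepare_inputs_for_generation(self, input_ids, src=None, src_mask=None, **kwargs):
        # On the first decoding step there may be no decoder tokens yet, so
        # seed the decoder input with decoder_start_token_id.
        if input_ids is None or input_ids.shape[-1] == 0:
            input_ids = torch.full((1, 1), self.config.decoder_start_token_id)
        # Map generate's `input_ids` onto the custom forward's `tgt` argument
        # and pass the remaining custom arguments through unchanged.
        return {"src": src, "src_mask": src_mask, "tgt": input_ids, **kwargs}

model = CustomSeq2Seq()
inputs = model.prepare_inputs_for_generation(torch.tensor([[2, 5]]), src=torch.ones(1, 3))
print(sorted(inputs))  # ['src', 'src_mask', 'tgt']
```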
Expected behavior
Hoping that things might work after following the official docs for subclassing and creating a custom model.
Hey @kirk86! Thanks for pointing out this issue. We are in the process of making generation more generalizable and easier to integrate with different model types, but in the meantime I can offer a workaround.
From what I see, it looks like your custom model is returning a dict, while `generate` expects a dataclass-type output (see here that `return_dict=True` is passed). The custom model's forward needs to return a dataclass object with `logits` when `return_dict=True`, so that it can be accessed via `output.logits` during generation.
```python
from dataclasses import dataclass

import torch

@dataclass
class MyOutput:
    logits: torch.Tensor = None
    loss: torch.Tensor = None

def forward(self, tensor, labels=None, return_dict=True):
    logits = self.model(tensor)
    loss = None
    if labels is not None:
        loss = torch.nn.functional.cross_entropy(logits, labels)
    if return_dict:
        # generate accesses this via `outputs.logits`
        return MyOutput(logits=logits, loss=loss)
    return (loss, logits)
```
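Putting the workaround together, here is a hedged sketch of a minimal generate-capable custom model — `TinyConfig`/`TinyModel` are made-up names, and it is a toy single-layer LM rather than the encoder-decoder from the issue. The key points are returning a `ModelOutput` subclass (here `CausalLMOutput`) from `forward` and overriding `prepare_inputs_for_generation`:

```python
import torch
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

class TinyConfig(PretrainedConfig):
    model_type = "tiny-custom"  # made-up model type for this sketch

    def __init__(self, vocab_size=16, hidden=8, **kwargs):
        self.vocab_size = vocab_size
        self.hidden = hidden
        super().__init__(**kwargs)

class TinyModel(PreTrainedModel, GenerationMixin):
    config_class = TinyConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed = torch.nn.Embedding(config.vocab_size, config.hidden)
        self.lm_head = torch.nn.Linear(config.hidden, config.vocab_size)

    def forward(self, input_ids=None, return_dict=True, **kwargs):
        logits = self.lm_head(self.embed(input_ids))
        if return_dict:
            # a dataclass output, so generate can read `outputs.logits`
            return CausalLMOutput(logits=logits)
        return (logits,)

    # overriding this marks the model as generation-capable
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

model = TinyModel(TinyConfig())
model.generation_config.pad_token_id = 0
out = model.generate(torch.tensor([[1, 2, 3]]), max_new_tokens=4, do_sample=False)
print(out.shape)  # torch.Size([1, 7]): 3 prompt tokens + 4 generated
```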
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.