Support generation kwargs within Seq2SeqTasks
🚀 Feature
Seq2Seq tasks (and tasks that inherit from them, like `SummarizationTask`) only allow a user to specify a couple of arguments to `model.generate`:

https://github.com/Lightning-AI/lightning-flash/blob/651e85851509fd04f723caedfef8d487d77df4e0/flash/text/seq2seq/core/model.py#L139-L144

However, the `generate` method from HF supports a ton of arguments and decoding strategies, specified by a `generation_config`. A lot of flexibility could be unlocked by allowing `Seq2SeqTask` to accept a `generation_config`.
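For context, here is a rough sketch of the kind of decoding control that `generate` exposes through a `generation_config` (this assumes a `transformers` version that ships `GenerationConfig`; it is only an illustration of the upstream API, not Flash code):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Decoding strategies that Seq2SeqTask cannot currently express, e.g. nucleus
# sampling with a temperature and a repetition penalty instead of beam search.
generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    repetition_penalty=1.2,
)

inputs = tokenizer("summarize: Lightning Flash makes fine-tuning easy.", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```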
Motivation
`Seq2SeqTask` appears to be the main interface to text generation within Flash. It would really open up a lot of flexibility for this class of tasks if a user could easily specify the decoding strategy.
Pitch
I think the change is quite straightforward (sketched below):

- Update `Seq2SeqTask` to accept a new argument, `generation_config`, matching the HuggingFace object
- Remove any arguments to `Seq2SeqTask` covered by this config (e.g. `num_beams`)
- Update `Seq2SeqTask.forward` so that it provides this config to `model.generate`
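A rough sketch of what this might look like; the names and signature below are simplified and hypothetical rather than the real `Seq2SeqTask` internals in `flash/text/seq2seq/core/model.py`:

```python
from typing import Optional

from transformers import GenerationConfig


class Seq2SeqTask:  # simplified; the real class inherits from flash.Task
    def __init__(self, generation_config: Optional[GenerationConfig] = None, **kwargs):
        # Arguments such as num_beams would be dropped in favour of the config.
        # Fall back to the backbone's default generation config when None.
        self.generation_config = generation_config

    def forward(self, x):
        # self.model is the HF backbone created elsewhere in the task.
        # Hand the whole config to generate instead of individual kwargs.
        return self.model.generate(
            input_ids=x["input_ids"],
            attention_mask=x.get("attention_mask"),
            generation_config=self.generation_config,
        )
```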
Alternatives
I believe something similar could be achieved by adding a new argument, `generation_kwargs`, which, similar to the above strategy, would be provided to `Seq2SeqTask` and passed as `**generation_kwargs` to `model.generate` via `Seq2SeqTask.forward`.
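An equally rough sketch of this alternative, with the same hypothetical simplifications as above:

```python
from typing import Any, Dict, Optional


class Seq2SeqTask:  # simplified; the real class inherits from flash.Task
    def __init__(self, generation_kwargs: Optional[Dict[str, Any]] = None, **kwargs):
        # A plain dict, e.g. {"num_beams": 4, "do_sample": True, "top_p": 0.9}
        self.generation_kwargs = generation_kwargs or {}

    def forward(self, x):
        # Unpack the user-supplied kwargs straight into HF generate.
        return self.model.generate(
            input_ids=x["input_ids"],
            attention_mask=x.get("attention_mask"),
            **self.generation_kwargs,
        )
```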
Additional context
Would be happy to work on a PR if the maintainers agree!