
Fix args for `decode_batch` in Blip2

Status: Open. Lucius-THU opened this issue 1 year ago • 0 comments

Issue: In examples/blip2/run.py, the arguments passed to the `decode_batch` method did not match its signature. The p-tuning arguments were passed positionally, which raised a TypeError: `GenerationSession.decode_batch() takes 3 to 4 positional arguments but 6 were given`.
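To illustrate how this kind of error arises, here is a minimal, self-contained reproduction. It does not use TensorRT-LLM: `FakeGenerationSession` is a hypothetical stand-in whose `decode_batch` signature (`input_ids`, `sampling_config`, optional `streaming`, then `**kwargs`) is an assumption chosen to match the error message, and the three list entries are placeholder values, not the real contents of `ptuning_args` in run.py.

```python
# Hypothetical stand-in for GenerationSession; the signature is an
# assumption modeled on the reported error (3 to 4 positional slots,
# counting self), with extra options accepted only as keywords.
class FakeGenerationSession:
    def decode_batch(self, input_ids, sampling_config, streaming=False, **kwargs):
        return {"input_ids": input_ids, "sampling_config": sampling_config,
                "streaming": streaming, **kwargs}


session = FakeGenerationSession()

# Pre-fix shape: p-tuning arguments kept in a list (placeholder values).
ptuning_args = ["prompt_table", "tasks", 32]

error_message = ""
try:
    # *ptuning_args spreads the list into positional slots:
    # self + input_ids + sampling_config + 3 items = 6 positional arguments.
    session.decode_batch([[1, 2, 3]], "sampling_config", *ptuning_args)
except TypeError as err:
    error_message = str(err)

print(error_message)
```

Note that even with fewer list entries the positional approach would be fragile: the first extra value would silently land in `streaming` instead of reaching `**kwargs`.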

Solution: To address this issue, I made the following changes:

  • Changed `ptuning_args` from a list to a dict keyed by the keyword-argument names.
  • Updated the `decode_batch` call to unpack `**ptuning_args`, so the p-tuning arguments are passed as keyword arguments instead of positionally.
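A minimal sketch of the fixed calling pattern, again using a hypothetical `FakeGenerationSession` stand-in rather than the real TensorRT-LLM class. The dict keys (`prompt_embedding_table`, `tasks`, `prompt_vocab_size`) are assumed names for illustration; the real keys must match whatever keyword arguments `decode_batch` forwards.

```python
# Hypothetical stand-in; signature assumed, not taken from TensorRT-LLM.
class FakeGenerationSession:
    def decode_batch(self, input_ids, sampling_config, streaming=False, **kwargs):
        # Echo everything back so the routing of each argument is visible.
        return {"input_ids": input_ids, "sampling_config": sampling_config,
                "streaming": streaming, **kwargs}


session = FakeGenerationSession()

# Post-fix shape: keys are keyword-argument names (assumed for illustration).
ptuning_args = {
    "prompt_embedding_table": "prompt_table",
    "tasks": "tasks",
    "prompt_vocab_size": 32,
}

# **ptuning_args passes each entry by name, so nothing spills into the
# positional slots and `streaming` keeps its default.
result = session.decode_batch([[1, 2, 3]], "sampling_config", **ptuning_args)
print(result["prompt_vocab_size"])  # 32
print(result["streaming"])          # False
```

Passing by name also makes the call independent of parameter order, so reordering the optional arguments in the callee would not break it.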

Expected Impact: The `decode_batch` call in examples/blip2/run.py no longer raises the TypeError and runs as intended.

Testing: I tested these changes locally before submitting this PR. The modified script ran without the previous TypeError, and `decode_batch` behaved as expected.

Lucius-THU, Feb 01 '24 02:02