Sampler: getting all generated tokens
The Sampler computation in this project needs all of the generated tokens.
```python
class Sampler(nn.Module):
    def __init__(self):
        super().__init__()
        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]

    def forward(
        self,
        logits_output: LogitsProcessorOutput,
        sampling_info: SamplingBatchInfo,
        return_logprob: bool,
        top_logprobs_nums: List[int],
    ):
        ...  # rest of the method not shown in the original post
```
So, how do I modify this code to get all the tokens that have been generated? Thanks.
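You'll need to modify the `forward` function so that it saves each token it generates (for example, by appending it to a list) and then returns that list. Generation usually happens token by token in a loop that calls the sampler once per decoding step, so the list has to be stored somewhere that persists across those calls, not in a local variable inside `forward`.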
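A minimal sketch of that idea, using a hypothetical `ToyBatchSampler` class rather than the project's actual `Sampler`. It assumes plain Python lists in place of tensors, greedy argmax in place of real sampling, and a batch whose membership does not change between steps:

```python
from typing import List


class ToyBatchSampler:
    """Hypothetical stand-in for a sampler (not the project's Sampler).

    Picks one token per request on every forward() call and keeps
    every token it has produced, per request.
    """

    def __init__(self, batch_size: int) -> None:
        # One list per request, stored on the object so it survives
        # across forward() calls (i.e. across decoding steps).
        self.generated_tokens: List[List[int]] = [[] for _ in range(batch_size)]

    def forward(self, logits: List[List[float]]) -> List[List[int]]:
        for req_idx, row in enumerate(logits):
            # Greedy choice (index of the largest logit) stands in for
            # real temperature / top-k / top-p sampling.
            next_token = max(range(len(row)), key=row.__getitem__)
            self.generated_tokens[req_idx].append(next_token)
        # Return copies of all tokens generated so far, not just the new ones.
        return [list(tokens) for tokens in self.generated_tokens]


sampler = ToyBatchSampler(batch_size=2)
sampler.forward([[0.1, 2.0, 0.3], [1.0, 0.0, 0.5]])               # step 1
all_tokens = sampler.forward([[3.0, 0.0, 0.0], [0.0, 0.2, 0.9]])  # step 2
print(all_tokens)  # [[1, 0], [0, 2]]
```

Returning copies keeps callers from mutating the sampler's stored history. In a server that adds and removes requests between steps, indexing by batch position as done here would likely need to be replaced with storage attached to each request.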
I added `decode_token_list: list = field(default_factory=list)` to `class SamplingBatchInfo:`.
Perfect. It should be working.
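Because `field(default_factory=list)` is dataclass syntax, the `default_factory` part matters: a dataclass rejects a bare mutable default such as `= []` with a `ValueError`, and `default_factory` builds a fresh empty list for each instance. The sketch below uses a hypothetical `BatchInfoSketch` class, not the real `SamplingBatchInfo`, with one made-up `temperature` field for illustration:

```python
from dataclasses import dataclass, field


@dataclass
class BatchInfoSketch:
    """Hypothetical, heavily trimmed stand-in for SamplingBatchInfo."""

    # Made-up field without a default. Fields without defaults must come
    # before fields with defaults, or @dataclass raises TypeError.
    temperature: float
    # The field from the thread: default_factory creates a new, empty list
    # for every instance. A bare `= []` default would raise ValueError.
    decode_token_list: list = field(default_factory=list)


a = BatchInfoSketch(temperature=0.7)
b = BatchInfoSketch(temperature=1.0)
a.decode_token_list.append(42)
print(a.decode_token_list, b.decode_token_list)  # [42] []
```

Appending to one instance's list leaves the other untouched, which is the behavior a per-batch token history needs.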
This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.