Adding length penalty in v5.0 of online_softmax_beamsearch_kernels

Open HossamAmer12 opened this issue 3 years ago • 3 comments

Hi there, I would like to multiply the sum of log probs by a length_penalty as applied to the most recent version of T.

I am using an older version (v5.0) of the FasterTransformer repo: https://github.com/NVIDIA/FasterTransformer/commit/952a1f2449f6bdd921feaba7cea68ae95d0d426b

I know that the code is in this file: src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu. The newest version of the code (v5.2) applies the penalty here: https://github.com/NVIDIA/FasterTransformer/blob/bc214067d603bf95beb8af5d1ce5960e96ba7244/src/fastertransformer/kernels/online_softmax_beamsearch_kernels.cu

Can you please point out where/how I can add the length penalty logic in v5.0 of the online_softmax_beamsearch_kernels linked above? Just adding the length penalty logic.

Any guidance would be appreciated.

HossamAmer12, Oct 03 '22 01:10

If you can explain what the TopKMD is doing in the old code, that'd be greatly appreciated.

HossamAmer12, Oct 03 '22 01:10

> Can you please point out where/how I can add the length penalty logic in v5.0 of the online_softmax_beamsearch_kernels linked above? Just adding the length penalty logic.

The online_softmax_beamsearch_kernels.cu files in these two commits are almost the same, so you can compare them directly.
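For orientation while comparing the two files, here is a minimal host-side sketch of what a length penalty typically computes. It is not copied from either commit: the function name `apply_length_penalty_sketch` is hypothetical, and the formula (dividing the cumulative log prob by `length^length_penalty`, which is the same as multiplying by its reciprocal) is an assumption based on a common beam search convention.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper, not taken from FasterTransformer.
// Assumed convention: score = cum_log_prob / length^length_penalty.
inline float apply_length_penalty_sketch(float cum_log_prob, int length, float length_penalty)
{
    assert(length > 0);  // a hypothesis holds at least one token
    // A penalty of 0, or a single-token hypothesis, leaves the score unchanged.
    if (length_penalty == 0.0f || length == 1) {
        return cum_log_prob;
    }
    return cum_log_prob / std::pow(static_cast<float>(length), length_penalty);
}
```

In a kernel, a function like this would be applied to each candidate's score before candidates are compared, so that longer hypotheses are not ranked lower only because they sum more negative terms.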

> If you can explain what the TopKMD is doing in the old code, that'd be greatly appreciated.

It is a structure that helps store the top-k elements and supports custom sorting while keeping track of each element's index.
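A rough host-side sketch of such a structure is shown below. All names (`TopKSketch`, `MDSketch`, `TopKMDSketch`, `merge_md`) are hypothetical, and the layout is an assumption: a fixed-size sorted top-k list with indices, paired with the running maximum `m` and denominator `d` that an online softmax needs. The real kernels run on the device and may differ in detail.

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <utility>

// Hypothetical sketch: K best (value, index) pairs in descending order.
template <int K>
struct TopKSketch {
    int   idx[K];  // element ids, -1 marks an empty slot
    float val[K];  // scores, best first

    void init()
    {
        for (int i = 0; i < K; ++i) { idx[i] = -1; val[i] = -FLT_MAX; }
    }

    // Insert one candidate; ties are broken in favor of the smaller id.
    void insert(float v, int id)
    {
        assert(id >= 0);
        const int last = K - 1;
        const bool better = idx[last] == -1 || v > val[last] || (v == val[last] && id < idx[last]);
        if (!better) return;
        val[last] = v;
        idx[last] = id;
        // Bubble the new entry toward the front until order is restored.
        for (int k = last; k > 0; --k) {
            const bool ahead = idx[k - 1] == -1 || val[k] > val[k - 1]
                               || (val[k] == val[k - 1] && idx[k] < idx[k - 1]);
            if (!ahead) break;
            std::swap(val[k], val[k - 1]);
            std::swap(idx[k], idx[k - 1]);
        }
    }
};

// Running max (m) and denominator (d) of an online softmax.
struct MDSketch { float m; float d; };

// Merge two partial softmax states, rescaling the smaller-max side.
inline MDSketch merge_md(MDSketch a, MDSketch b)
{
    const bool a_bigger = a.m > b.m;
    const MDSketch hi = a_bigger ? a : b;
    const MDSketch lo = a_bigger ? b : a;
    return {hi.m, hi.d + lo.d * std::exp(lo.m - hi.m)};
}

// Assumed pairing: softmax state plus top-k list, reduced together.
template <int K>
struct TopKMDSketch { MDSketch md; TopKSketch<K> topk; };
```

Keeping both pieces in one struct lets a single reduction pass produce the top-k candidates and the softmax normalizer at the same time, instead of reading the logits twice.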

byshiue, Oct 04 '22 03:10

Thanks, @byshiue.

Just one question: beam_online_softmax_topk_stage1_kernel depends on the vocab size parameter to move to the right memory location of the log probs. Do beam_online_softmax_topk_stage2_kernelLauncher and/or batch_topk_kernel also depend on the vocab size parameter in your implementation of v5.0/v5.2? And where do you calculate the cumulative log probs? In stage2 and stage3, I was not able to find where you accumulate the sum.

Can you please help by pointing them out?
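To make the two parts of this question concrete, the sketch below shows generic beam search bookkeeping, not the FasterTransformer kernels themselves. It assumes a row-major `[batch, beam, vocab]` buffer, candidate ids encoded as `beam * vocab_size + token`, and a cumulative score formed by adding the new token's log-softmax value to the parent beam's previous score. Every name here is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Offset of one token's score in an assumed row-major [batch, beam, vocab] buffer.
inline std::size_t logit_offset(int batch, int beam, int token, int beam_width, int vocab_size)
{
    return (static_cast<std::size_t>(batch) * beam_width + beam) * vocab_size + token;
}

// Log-softmax of one logit, given the row maximum m and the
// denominator d = sum over j of exp(logit_j - m).
inline float log_softmax_from_md(float logit, float m, float d)
{
    return logit - m - std::log(d);
}

struct CandidateSketch {
    int   beam;          // parent beam the candidate extends
    int   token;         // vocabulary id of the new token
    float cum_log_prob;  // parent's score plus the new token's log prob
};

// Decode an assumed flat id (beam * vocab_size + token) and accumulate the score.
inline CandidateSketch expand_candidate(int flat_id, float log_prob,
                                        const std::vector<float>& prev_cum_log_probs,
                                        int vocab_size)
{
    const int beam  = flat_id / vocab_size;
    const int token = flat_id % vocab_size;
    return {beam, token, prev_cum_log_probs[beam] + log_prob};
}
```

Under these assumptions, any stage that decodes a flat candidate id back into a beam and a token needs the vocab size, and the accumulation is a single addition at the point where a candidate's score is written out.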

HossamAmer12, Oct 04 '22 22:10