ponder-transformer
Evaluating PonderNet on more pondering steps than it was trained on
As the paper says,
In evaluation, and under known temporal or computational limitations, N can be set naively as a constant (or not set any limit, i.e. N → ∞). For training, we found that a more effective (and interpretable) way of parameterizing N is by defining a minimum cumulative probability of halting. N is then the smallest value of n such that $\sum_{j=1}^{n} p_j > 1 - \varepsilon$, with the hyper-parameter ε positive near 0 (in our experiments 0.05).
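To make the quoted rule concrete, here is a minimal sketch in plain Python. It is not taken from this repository: the function name `smallest_halting_step` and the example halting distribution are made up for illustration, and `p` is assumed to already hold the unconditional halting probabilities p_1, p_2, ...

```python
from itertools import accumulate


def smallest_halting_step(p, eps=0.05):
    """Return the smallest n (1-indexed) such that p[0] + ... + p[n-1] > 1 - eps.

    Returns None if the cumulative probability never exceeds the threshold.
    """
    for n, total in enumerate(accumulate(p), start=1):
        if total > 1.0 - eps:
            return n
    return None


# Hypothetical halting distribution: geometric with rate 0.5.
# Cumulative sums are 0.5, 0.75, 0.875, 0.9375, 0.96875, ...
p = [0.5 * 0.5 ** k for k in range(10)]
print(smallest_halting_step(p, eps=0.05))  # -> 5
```

With ε = 0.05 the threshold is 0.95, which the cumulative sum first exceeds at the fifth step.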
From that I infer that the model can ponder for more steps at evaluation than it was trained on. How can this be done with this implementation?
Edit: I was going through the paper again, and I think what it means is that the maximum number of pondering steps (max_num_pondering_steps: N) should be re-evaluated at every training step. The model should be run until the condition is met or a predefined maximum number of steps is reached, and N is set to the step at which the cumulative-probability (cumsum_probs) condition is met, with the probabilities normalised by one of the methods. That value of N is then used to compute the geometric prior for the KL-divergence term (kl_div), without normalising the prior.
I.e., if the number of pondering steps is initially set to M, the model recurs for k steps, that is, until the condition is met or M steps have been run. N is then computed by first calculating the probabilities p_0 to p_k, normalising them with one of the methods, taking their cumulative sum, and finding the first step at which that sum exceeds the threshold (1 − ε); that step is N. After that, the geometric prior is computed with the defined hyper-parameter for a sequence length of N and used in the KL-divergence term against the halting probabilities truncated to N steps.
$\lambda_p$ is a hyper-parameter that defines a geometric prior distribution $p_G(\lambda_p)$ on the halting policy (truncated at N)
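The procedure described above could be sketched roughly as follows, in plain Python rather than with this repository's code. Everything here is an assumption made for illustration: the function names, the choice of renormalising by the total as the normalisation method, the fallback to M when the threshold is never crossed, and the example values. The prior is left unnormalised after truncation, as proposed in the edit above.

```python
import math


def halting_distribution(lambdas):
    """Unconditional p_n = lambda_n * prod_{i<n} (1 - lambda_i) from conditional lambdas."""
    p, still_running = [], 1.0
    for lam in lambdas:
        p.append(still_running * lam)
        still_running *= 1.0 - lam
    return p


def find_n(p, eps):
    """First step (1-indexed) where the cumulative sum exceeds 1 - eps, else len(p)."""
    total = 0.0
    for n, x in enumerate(p, start=1):
        total += x
        if total > 1.0 - eps:
            return n
    return len(p)  # assumed fallback: the maximum number of steps M


def geometric_prior(lambda_p, n_steps):
    """Geometric prior truncated at n_steps, deliberately not renormalised."""
    return [lambda_p * (1.0 - lambda_p) ** k for k in range(n_steps)]


def kl_to_prior(lambdas, lambda_p=0.2, eps=0.05):
    p = halting_distribution(lambdas)
    total = sum(p)
    p = [x / total for x in p]           # assumed normalisation: divide by the sum
    n = find_n(p, eps)                   # N for this training step
    prior = geometric_prior(lambda_p, n)
    # KL(p || prior) over the first N steps only
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p[:n], prior) if pi > 0)
    return n, kl


# Hypothetical run: M = 8 steps, each with conditional halting probability 0.5.
n, kl = kl_to_prior([0.5] * 8, lambda_p=0.2, eps=0.05)
print(n, kl)
```

In this example N comes out as 5, so only the first five halting probabilities enter the KL term. Whether truncating without renormalising is what the paper intends is exactly the open question in this post.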