tutel
About compute_sorted_location and locations_s
Thanks for your excellent work on Tutel.
I would like to know what the following function (in fast_dispatch.py) does:
    def compute_sorted_location(x, importance_scores):
        sorted_x = x[importance_scores.argsort(dim=0)]
        sorted_cumsum = fast_cumsum_sub_one(sorted_x) * sorted_x
        return sorted_cumsum[importance_scores.argsort(dim=0).argsort(dim=0)]
I would also like to know the meaning of locations_s, which is returned by the function extract_critical (also in fast_dispatch.py).
locations_s stores a list of unique destination indices: for each input token, the position it will be written to in the dispatch step that follows.
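To make this concrete, here is a small pure-Python sketch of the idea. It is not the Tutel code itself. It assumes that x is a 0/1 token-by-expert mask and that fast_cumsum_sub_one is a cumulative sum along the token dimension minus one. The dispatch helper, the capacity value, and the token names are hypothetical and only there for illustration.

```python
def cumsum_sub_one(mask):
    # Assumed behaviour of fast_cumsum_sub_one: per-expert running count - 1.
    running = [0] * len(mask[0])
    out = []
    for row in mask:
        running = [r + v for r, v in zip(running, row)]
        out.append([r - 1 for r in running])
    return out


def compute_sorted_location(mask, importance_scores):
    # order[k] is the token with the k-th smallest score (ascending argsort).
    order = sorted(range(len(mask)), key=lambda i: importance_scores[i])
    sorted_mask = [mask[i] for i in order]
    counts = cumsum_sub_one(sorted_mask)
    # Multiplying by the mask keeps a slot only for the expert a token chose.
    sorted_loc = [[c * m for c, m in zip(c_row, m_row)]
                  for c_row, m_row in zip(counts, sorted_mask)]
    # Undo the sort (this is what argsort().argsort() does in the original).
    locations = [None] * len(mask)
    for rank, token in enumerate(order):
        locations[token] = sorted_loc[rank]
    return locations


def dispatch(tokens, expert_ids, slots, num_experts, capacity):
    # Hypothetical dispatch: each (expert, slot) pair is a unique buffer index.
    buffer = [None] * (num_experts * capacity)
    for tok, e, s in zip(tokens, expert_ids, slots):
        if s < capacity:  # tokens past the capacity are dropped in this sketch
            buffer[e * capacity + s] = tok
    return buffer


# 4 tokens, 2 experts: tokens 0, 1, 3 chose expert 0; token 2 chose expert 1.
mask = [[1, 0], [1, 0], [0, 1], [1, 0]]
scores = [0.3, 0.1, 0.2, 0.0]  # smaller score is served first
expert_ids = [0, 0, 1, 0]

locations = compute_sorted_location(mask, scores)
# One slot per token, read from the column of the chosen expert.
slots = [locations[t][e] for t, e in enumerate(expert_ids)]
buffer = dispatch(["t0", "t1", "t2", "t3"], expert_ids, slots,
                  num_experts=2, capacity=2)

print(locations)  # [[2, 0], [1, 0], [0, 0], [0, 0]]
print(slots)      # [2, 1, 0, 0]
print(buffer)     # ['t3', 't1', 't2', None]
```

In this example the tokens routed to expert 0 receive slots in order of ascending score (t3, then t1, then t0), rather than in their original order. With a capacity of 2, t0 ends up at slot 2 and is not written. The per-token slots list plays the role described for locations_s above.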