Question about topk_index = topk_index.view(-1) in cnets.py
Dear Authors,

Thank you for your outstanding work on EAGLE and EAGLE-2! I have a question about the implementation, specifically the line topk_index = topk_index.view(-1).

Here is an example that illustrates my confusion. When executing topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k) with top_k=3, we obtain 4 branches, each holding 3 candidate tokens:

[[55, 67, 33], [25578, 13, 45], [2, 56, 42], [32, 64, 89]]

After applying topk_index = topk_index.view(-1), the output becomes:

[55, 67, 33, 25578, 13, 45, 2, 56, 42, 32, 64, 89]
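As a rough illustration, assuming topk_index is a contiguous tensor of shape [4, 3], view(-1) flattens it in row-major order, so the token at rank r of branch b lands at flat position b * top_k + r. The sketch below uses plain Python lists in place of torch tensors (an assumption made to keep it dependency-free) and checks that mapping with divmod.

```python
# Pure-Python stand-in for topk_index.view(-1) on a contiguous
# [num_branches, top_k] tensor; lists replace torch tensors here.
top_k = 3
# Example top-k token ids from the question: one row of top_k candidates per branch.
topk_index = [[55, 67, 33], [25578, 13, 45], [2, 56, 42], [32, 64, 89]]

# Row-major flatten: rows are concatenated in order, like view(-1).
flat = [token for row in topk_index for token in row]
print(flat)  # [55, 67, 33, 25578, 13, 45, 2, 56, 42, 32, 64, 89]

# Flat position j maps back to (branch, rank) = divmod(j, top_k),
# so positions 0..2 belong to branch 0, 3..5 to branch 1, and so on.
for j, token in enumerate(flat):
    branch, rank = divmod(j, top_k)
    assert topk_index[branch][rank] == token

print(divmod(2, top_k))  # (0, 2): position 2 is the third token of branch 0
```

Under this layout, any flat index below top_k refers to the first branch.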
Then, if self.tree_buffer['tree_indices'][i] = [0, 2], we get select_index = [55, 33]. Is my understanding correct?

I initially expected the implementation to select tokens from each branch. Instead, it appears to select tokens at specific positions in the flattened array, so in this example both selected tokens come from the first branch and the other three branches are not selected at all. Is there a particular reason for ignoring certain branches, or is this a deliberate design choice in the tree buffer indices?

I would greatly appreciate your time!

Regards,