topKgate loss issues

Open · powermano opened this issue 1 year ago · 0 comments

The gate computes a loss (stored in `self.loss` in the code below), but does this loss have any effect on training? Where is it used?

logits = self.wg(input)  # dim: [bxs, num_experts]
if self.k == 1:
    # Top-1 gating: the first return value is the gate loss, kept on self.loss;
    # only gates1_s, dispatch_mask and retval are returned to the caller.
    self.loss, self.gate_log, gates1_s, dispatch_mask, retval = top1gating(
        logits,
        self.capacity_factor if self.training else self.eval_capacity_factor,
        is_expert_slicing=self.is_expert_slicing,
        fp16_mode=self.fp16_mode,
        nonpadding=nonpadding,
        logits_gumbel=self.logits_gumbel if self.training else 0,
        token_drop_type=self.token_drop_type,
        straight_through=self.straight_through,
        straight_through_temperature=self.straight_through_temperature,
        balance_ratio=self.balance_ratio,
        gate_log_req=self.gate_log_req,
        lid=lid,
        tutel_cumsum_sub_one=self.tutel_cumsum_sub_one,
    )
    return gates1_s, dispatch_mask, retval
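For context, a gate loss like this is usually an auxiliary load-balancing loss: it only influences training if the model collects it (typically summed over all MoE layers) and adds it to the objective that is back-propagated. The snippet above does not show whether this repository does that. The pure-Python sketch below assumes a Switch Transformer-style formulation, `num_experts * sum(me * ce)`, where `me` is the mean router probability per expert and `ce` is the fraction of tokens routed to each expert. The function `top1_balance_loss` is hypothetical, not this repository's `top1gating`, and `balance_ratio` is only a name borrowed from the snippet for the loss coefficient.

```python
import math
from typing import List, Tuple


def softmax(row: List[float]) -> List[float]:
    # Numerically stable softmax over one token's expert logits.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def top1_balance_loss(logits: List[List[float]]) -> Tuple[List[int], float]:
    """Hypothetical helper: top-1 expert per token plus a load-balancing loss.

    logits: [num_tokens][num_experts], analogous to the output of self.wg(input).
    """
    num_tokens = len(logits)
    num_experts = len(logits[0])
    gates = [softmax(row) for row in logits]
    # Top-1 routing: each token goes to its highest-probability expert.
    assignments = [max(range(num_experts), key=lambda e: g[e]) for g in gates]
    # me[e]: mean router probability assigned to expert e over all tokens.
    me = [sum(g[e] for g in gates) / num_tokens for e in range(num_experts)]
    # ce[e]: fraction of tokens dispatched to expert e.
    ce = [assignments.count(e) / num_tokens for e in range(num_experts)]
    # Equals 1.0 when both routing and router probabilities are uniform;
    # grows when tokens pile up on a few experts.
    l_aux = num_experts * sum(m * c for m, c in zip(me, ce))
    return assignments, l_aux


# Hypothetical training step: the auxiliary loss affects training only
# because it is added to the loss that would be back-propagated.
logits = [[2.0, 0.5], [1.5, 0.2], [0.1, 1.0], [0.3, 2.2]]
assignments, l_aux = top1_balance_loss(logits)
task_loss = 0.7  # placeholder for, e.g., a cross-entropy value
balance_ratio = 0.01  # assumed coefficient; name borrowed from the snippet
total_loss = task_loss + balance_ratio * l_aux
```

In an autograd framework, the gradient of this loss reaches the gate weights through `me` (the softmax probabilities); `ce` comes from an argmax and is not differentiable. If the stored loss is never added to the objective, the gate is trained only through the task loss.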

powermano, Jul 31 '23 11:07