topKgate loss issues
The gate computes a loss and stores it in self.loss, but does this have any effect on training? Where is this loss used?
logits = self.wg(input)  # dim: [bxs, num_experts]
if self.k == 1:
    self.loss, self.gate_log, gates1_s, dispatch_mask, retval = top1gating(
        logits,
        self.capacity_factor if self.training else self.eval_capacity_factor,
        is_expert_slicing=self.is_expert_slicing,
        fp16_mode=self.fp16_mode,
        nonpadding=nonpadding,
        logits_gumbel=self.logits_gumbel if self.training else 0,
        token_drop_type=self.token_drop_type,
        straight_through=self.straight_through,
        straight_through_temperature=self.straight_through_temperature,
        balance_ratio=self.balance_ratio,
        gate_log_req=self.gate_log_req,
        lid=lid,
        tutel_cumsum_sub_one=self.tutel_cumsum_sub_one,
    )
    return gates1_s, dispatch_mask, retval
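
For context on the question: in the excerpt above the loss is only stored on the gate (self.loss) and is not part of the return value, so it can influence training only if some other code reads it and adds it to the objective that is backpropagated. The sketch below illustrates that idea under stated assumptions. It uses a GShard-style top-1 balance loss on plain Python lists, which may differ from what `top1gating` actually computes here, and `balance_loss`, `total_loss`, and `aux_weight` are hypothetical names, not this library's API.

```python
import math


def softmax(row):
    """Numerically stable softmax over one token's expert logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def balance_loss(logits):
    """GShard-style top-1 balance loss (an assumption, not this library's formula).

    logits: list of per-token lists, shape [num_tokens][num_experts].
    Returns num_experts * sum_e(me[e] * ce[e]).
    """
    num_tokens = len(logits)
    num_experts = len(logits[0])
    gates = [softmax(row) for row in logits]
    # me[e]: mean gate probability given to expert e across tokens
    me = [sum(g[e] for g in gates) / num_tokens for e in range(num_experts)]
    # ce[e]: fraction of tokens whose top-1 expert is e
    top1 = [max(range(num_experts), key=lambda e: g[e]) for g in gates]
    ce = [top1.count(e) / num_tokens for e in range(num_experts)]
    return num_experts * sum(m * c for m, c in zip(me, ce))


def total_loss(task_loss, gate_losses, aux_weight=0.01):
    """Hypothetical combination step: task loss plus weighted gate losses.

    Without a step like this, a loss stored on the gate never reaches
    the optimizer.
    """
    return task_loss + aux_weight * sum(gate_losses)


# Tokens split evenly between two experts vs. all tokens sent to expert 0.
balanced = balance_loss([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0], [0.0, 2.0]])
collapsed = balance_loss([[2.0, 0.0]] * 4)
combined = total_loss(task_loss=2.0, gate_losses=[balanced, balanced])
```

Under this formulation the collapsed routing yields a larger loss than the balanced one, so adding the term to the objective pushes the gate toward even expert usage. To check whether this codebase does so, searching for reads of the gate's `loss` attribute in the training loop or the enclosing MoE layer would be the place to start.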