MAML-Pytorch

Incorrect losses_q

Open • zilin129 opened this issue 4 years ago • 3 comments

losses_q is supposed to hold the query-set losses across all tasks, but it is defined as losses_q = [0 for _ in range(self.update_step + 1)]. That looks like the query-set losses for each update step of task i only, not for all tasks. Is this an error?

zilin129 · Mar 26 '21 22:03


So, are you suggesting losses_q = [0 for _ in range(self.update_step * task_num + 1)] to accumulate all query losses?
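For comparison, here is a minimal sketch, in plain Python, of what keeping every (task, step) query loss separately would look like, and how it reduces to the per-step sums. The function names and query_loss are hypothetical stand-ins, not part of MAML-Pytorch; query_loss(task, step) represents the query-set loss of a task after step inner updates, with step 0 meaning before any update.

```python
def all_query_losses(query_loss, task_num, update_step):
    """Store every (task, step) query loss in one flat list (hypothetical helper)."""
    # Each task contributes one loss per step, including step 0, so the
    # flat list needs task_num * (update_step + 1) slots.
    return [query_loss(i, k)
            for i in range(task_num)
            for k in range(update_step + 1)]


def per_step_sums(flat, task_num, update_step):
    """Reduce the flat list to one sum per step, summed over tasks."""
    steps = update_step + 1
    # Entry k adds up task i's loss at step k for every task i.
    return [sum(flat[i * steps + k] for i in range(task_num))
            for k in range(steps)]


def toy_loss(i, k):
    # Toy values: the loss shrinks as the number of inner updates k grows.
    return float((i + 1) * (3 - k))


flat = all_query_losses(toy_loss, task_num=2, update_step=2)
print(len(flat))                         # 6
print(per_step_sums(flat, 2, 2))         # [9.0, 6.0, 3.0]
```

Storing every pair takes task_num * (update_step + 1) entries, which for task_num > 1 is more than update_step * task_num + 1. The per-step sums are exactly what an update_step + 1 list accumulates directly, so the larger layout would only matter if you wanted per-task losses for logging.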

NookLook2014 · Apr 07 '21 08:04

No, it is not an error. The code creates a list with one entry per inner update step, plus one for the loss before any update, and then accumulates each task's query loss into the entry for that step. For example, losses_q[3], the fourth entry, is the sum over all tasks of the query loss after three inner updates. It then takes only the last entry and averages it over the tasks with loss_q = losses_q[-1] / task_num on line 134.
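A simplified sketch of that bookkeeping in plain Python, assuming a hypothetical query_loss(task, step) that stands in for the repo's forward passes on each task's query set after step inner updates (the real code computes these with PyTorch and handles steps 0 and 1 separately before its inner loop):

```python
def meta_query_loss(query_loss, task_num, update_step):
    """Accumulate query losses per step across tasks, then average the last one."""
    # One slot per inner update step, plus slot 0 for the loss before any update.
    losses_q = [0.0 for _ in range(update_step + 1)]
    for i in range(task_num):
        for k in range(update_step + 1):
            # += across tasks: losses_q[k] ends up as the sum of every
            # task's query loss after k inner updates.
            losses_q[k] += query_loss(i, k)
    # Only the final step is used, averaged over the tasks.
    return losses_q, losses_q[-1] / task_num


def toy_loss(i, k):
    # Toy values: task i's loss shrinks as the number of inner updates k grows.
    return float((i + 1) * (6 - k))


per_step, loss_q = meta_query_loss(toy_loss, task_num=4, update_step=5)
print(per_step)  # [60.0, 50.0, 40.0, 30.0, 20.0, 10.0]
print(loss_q)    # 2.5
```

The list length depends only on update_step, not on task_num, because the task dimension is summed away as the loop runs.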

tim-hash · Aug 12 '21 17:08

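No, it is not an error: each entry is accumulated across tasks with +=, and only the last entry, losses_q[-1], is used to compute the final loss.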

woundenfish · Aug 13 '21 13:08