SEGA
How are the base classes split in stage 2 of 01_miniimagenet_stage2_SEGA_5W1S?
When I debug 01_miniimagenet_stage2_SEGA_5W1S, I get the following output (printed inside def train_stage2(opt) in traincode.py):
Knovel_ids.size() torch.Size([8, 5]) Kbase_ids.size() torch.Size([8, 59])
logit_query.size() torch.Size([8, 60, 64])
It seems the 64 base classes are divided into 59 Kbase classes and 5 Knovel classes, and the model still performs 64-way classification. Is that right?
Could you please give some more details about these results? Thanks!
Hi, I think the explanation you are looking for is at the end of the 'Training Procedure' paragraph in Section 3.3 (Framework) of our paper:
More specifically, for each episode, we randomly sample N classes from the base classes Y_b to act as “novel” classes, then sample K samples from each “novel” class to form a fake N-Way K-Shot support set. As shown in Figure 3, we can calculate N visual prototypes and enhance them using semantic guided attentions. Thus we get N classification weights which are used to replace the corresponding base classification weights (other weights are also enhanced by their own semantic attentions) in Cosine Classifier, and then perform classification and cross-entropy loss calculation.
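As an illustration, here is a minimal pure-Python sketch of one such episode. It assumes plain averaged prototypes (the semantic guided attention enhancement is omitted), a fixed cosine scale of 10, and small random features; the function names (`prototypes`, `episode_logits`, etc.) are hypothetical and are not taken from the SEGA code.

```python
import math
import random

def l2_normalize(v):
    """Scale a vector to unit length (guarding against a zero vector)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def prototypes(support, labels, n_way):
    """Average the K support features of each fake-novel class."""
    dim = len(support[0])
    sums = [[0.0] * dim for _ in range(n_way)]
    counts = [0] * n_way
    for feat, y in zip(support, labels):
        counts[y] += 1
        sums[y] = [s + f for s, f in zip(sums[y], feat)]
    return [[s / c for s in row] for row, c in zip(sums, counts)]

def episode_logits(query, weight_base, knovel_ids, novel_weights, scale=10.0):
    """Cosine classifier over ALL base classes, with the fake-novel rows replaced."""
    weights = [list(w) for w in weight_base]          # copy the learned base weights
    for cls_id, w in zip(knovel_ids, novel_weights):
        weights[cls_id] = list(w)                     # swap in the generated weight
    weights = [l2_normalize(w) for w in weights]
    logits = []
    for q in query:
        qn = l2_normalize(q)
        logits.append([scale * sum(a * b for a, b in zip(qn, w)) for w in weights])
    return logits

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy, computed with the log-sum-exp trick."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[t]
    return total / len(logits)

# Toy 5-way 1-shot episode drawn from 64 base classes (all values are random).
rng = random.Random(0)
NUM_BASE, N_WAY, K_SHOT, DIM, N_QUERY = 64, 5, 1, 16, 60
weight_base = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(NUM_BASE)]
knovel_ids = rng.sample(range(NUM_BASE), N_WAY)
support = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(N_WAY * K_SHOT)]
labels = [i // K_SHOT for i in range(N_WAY * K_SHOT)]
protos = prototypes(support, labels, N_WAY)

query = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(N_QUERY)]
query[0] = list(protos[0])                            # a query identical to a prototype
targets = [rng.randrange(NUM_BASE) for _ in range(N_QUERY)]
targets[0] = knovel_ids[0]                            # labels live in the 64-class space

logits = episode_logits(query, weight_base, knovel_ids, protos)
loss = cross_entropy(logits, targets)
print(len(logits), len(logits[0]))  # 60 64
```

Note that the classifier always has one row per base class, so the logits stay 64-way even though only 5 rows are generated from the support set in each episode.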
In short, yes: we always perform 64-way classification. In each episode, 5 classes are sampled to act as "novel" classes and their classification weights are generated from the fake support set, while the weights of the other 59 "base" classes come from the trained parameters SEGAhead.weight_base, just as in the first training stage. By the way, this training strategy comes from Dynamic-FSL.
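To connect this with the printed sizes, the following hypothetical sketch mimics the per-episode split for a batch of 8 episodes. It assumes a plain random permutation of the 64 base class ids per episode and 60 query samples per episode (taken from the printed logit_query size); it is not the repository's actual sampler.

```python
import random

NUM_BASE = 64  # miniImageNet base (training) classes
N_WAY = 5      # classes sampled to act as fake "novel" classes per episode
BATCH = 8      # episodes per batch (first dim of the printed sizes)
N_QUERY = 60   # query samples per episode (second dim of logit_query)

def split_base_classes(rng):
    """Randomly split the base class ids into fake-novel and remaining base ids."""
    perm = rng.sample(range(NUM_BASE), NUM_BASE)
    return perm[:N_WAY], perm[N_WAY:]

def weight_sources(knovel):
    """Label where each of the 64 classifier rows comes from in one episode."""
    novel = set(knovel)
    return ["generated" if c in novel else "weight_base" for c in range(NUM_BASE)]

rng = random.Random(0)
episodes = [split_base_classes(rng) for _ in range(BATCH)]
Knovel_ids = [k for k, _ in episodes]
Kbase_ids = [b for _, b in episodes]
sources = [weight_sources(k) for k in Knovel_ids]
logit_shape = (BATCH, N_QUERY, len(sources[0]))

print((len(Knovel_ids), len(Knovel_ids[0])))  # (8, 5)
print((len(Kbase_ids), len(Kbase_ids[0])))    # (8, 59)
print(logit_shape)                            # (8, 60, 64)
print(sources[0].count("generated"), sources[0].count("weight_base"))  # 5 59
```

Because Knovel and Kbase together always cover all 64 ids, the classifier never loses a class; only the source of 5 of its rows changes from episode to episode.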