SAN
Questions about computing the loss
Hi, thank you for your great work! I'm confused about several points in the loss computation.
- For the Lcls loss, is the ground truth the class label of each image? (My current guess is sketched after this list.)
- For the Lmask_dice and Lmask_bce losses, is the ground truth the binary mask without any class information?
- Why do you generate point labels? The relevant code is shown below (my rough understanding of how these sampled points feed the mask losses is sketched after this list):
```python
with torch.no_grad():
    # Sample point coordinates, biased towards the most uncertain
    # (ambiguous) regions of the predicted masks.
    point_coords = get_uncertain_point_coords_with_randomness(
        src_masks,
        lambda logits: calculate_uncertainty(logits),
        self.num_points,
        self.oversample_ratio,
        self.importance_sample_ratio,
    )
    # Sample the ground-truth masks at the same coordinates to get
    # per-point labels.
    point_labels = point_sample(
        target_masks,
        point_coords,
        align_corners=False,
    ).squeeze(1)
```
- You used prompt engineering: reading the code, I found that templates are used to generate ov_classifier_weight in san.py lines 161-164 during training (my understanding of this step is sketched after this list). Have you used prompt engineering anywhere else?
- Where do you use prompt engineering and class information at inference time? Is the class list treated as known information for each image when inferring?
- Can this work be considered zero-shot semantic segmentation, and why?
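
For reference, here is my current guess about Lcls (a minimal sketch assuming a DETR/Mask2Former-style criterion with Hungarian matching; names like `matched_query_idx` are only illustrative, please correct me if this is not what the code does):

```python
import torch
import torch.nn.functional as F

def classification_loss_sketch(pred_logits, matched_gt_classes, matched_query_idx, num_classes):
    """Guess at Lcls: cross-entropy per query, where each query matched to a
    ground-truth mask takes that mask's class label and every unmatched query
    takes the extra "no-object" class (index == num_classes)."""
    B, Q, _ = pred_logits.shape  # (batch, queries, num_classes + 1)
    # Default every query to the "no-object" class.
    target = torch.full((B, Q), num_classes, dtype=torch.long, device=pred_logits.device)
    # matched_query_idx is a (batch_indices, query_indices) pair from the matcher.
    target[matched_query_idx] = matched_gt_classes
    # cross_entropy expects (B, C, Q) logits against (B, Q) targets.
    return F.cross_entropy(pred_logits.transpose(1, 2), target)
```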
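
And my rough understanding of why point labels are generated: the dice/BCE mask losses are evaluated only at the sampled points instead of on full-resolution masks, which saves memory and focuses supervision on uncertain boundary regions. A minimal sketch of how the sampled points might be consumed (the `point_sample` import is the same detectron2 PointRend utility used in the snippet above; the reductions here are simplified relative to the real criterion):

```python
import torch
import torch.nn.functional as F
from detectron2.projects.point_rend.point_features import point_sample

def point_mask_loss_sketch(src_masks, target_masks, point_coords):
    """Guess at Lmask_bce / Lmask_dice: sample predictions and binary GT masks
    at the same point coordinates, then compute BCE and dice on the points.
    src_masks / target_masks: (N, 1, H, W); point_coords: (N, P, 2) in [0, 1]."""
    point_logits = point_sample(src_masks, point_coords, align_corners=False).squeeze(1)
    point_labels = point_sample(target_masks, point_coords, align_corners=False).squeeze(1)

    # Class-agnostic binary cross-entropy on the sampled points.
    loss_bce = F.binary_cross_entropy_with_logits(point_logits, point_labels)

    # Dice loss on the sampled points.
    probs = point_logits.sigmoid()
    numerator = 2 * (probs * point_labels).sum(-1)
    denominator = probs.sum(-1) + point_labels.sum(-1)
    loss_dice = (1 - (numerator + 1) / (denominator + 1)).mean()
    return loss_bce, loss_dice
```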
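
On the prompt-engineering question, my reading of how ov_classifier_weight could be produced from templates is roughly the following (a sketch written against OpenAI's `clip` package with a shortened template list; the repo's actual CLIP wrapper, model name, and templates may differ):

```python
import torch
import clip  # OpenAI CLIP; the repo may use a different CLIP implementation

# Illustrative templates only; the actual template list in the repo may differ.
TEMPLATES = ["a photo of a {}.", "a photo of the {}.", "a painting of a {}."]

@torch.no_grad()
def build_ov_classifier_weight(class_names, device="cuda"):
    """Sketch: embed every (template, class name) prompt with the CLIP text
    encoder, average over templates, and L2-normalize to get one classifier
    vector per class."""
    model, _ = clip.load("ViT-B/16", device=device)
    weights = []
    for name in class_names:
        prompts = [t.format(name) for t in TEMPLATES]
        tokens = clip.tokenize(prompts).to(device)
        emb = model.encode_text(tokens)             # (num_templates, dim)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
        mean_emb = emb.mean(dim=0)                  # average over templates
        mean_emb = mean_emb / mean_emb.norm()       # renormalize the class embedding
        weights.append(mean_emb)
    return torch.stack(weights, dim=0)              # (num_classes, dim)
```

For example, `build_ov_classifier_weight(["person", "car", "tree"])` would give one normalized text embedding per class, which I assume is then used as the classification weights against the mask embeddings.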
These questions have confused me a lot! I'm looking forward to your reply.