AutoTCL
Question about mask h(x)
Hello author, I have a question. In `get_features` I can't find where the mask h decomposes the time series x into the informative part x*. Could you please point me to it? It looks as if x is augmented directly to form ax:

```python
def get_features(self, x, training = True, n_epochs=-1, mask=None):
    embedding = self.augsharenet(x)
    weight_h = self.factor_augnet(embedding)
    weight_s = self.augmentation_projector(embedding)
    mask_h = self._sample_graph(weight_h,training= training) # output shape (32, 257, 8)
    if self.hard_mask:
        hard_mask_h = (torch.sign(mask_h-0.5)+1)/2
        # print(hard_mask_h)
        mask_h = (mask_h-hard_mask_h).detach()+hard_mask_h
        mask_h = (hard_mask_h - mask_h).detach()+mask_h
    ax = weight_s * mask_h * x # augmented x'
    if torch.isnan(ax).any() or torch.isnan(x).any():
        exit(1)
    # note: I add mask
    out1 = self._net(x,mask) # representation
    out2 = self._net(ax,mask) # representation of augmented x'
    return x, ax, out1, out2, weight_h
```
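The `hard_mask` branch relies on a straight-through trick. The forward pass sees a 0/1 mask, while gradients follow the soft mask. Below is a minimal sketch of that pattern, assuming PyTorch. `straight_through_hard_mask` is a hypothetical helper, not a function from the AutoTCL repository, and the toy `soft` and `x` tensors stand in for `mask_h` and the input series.

```python
import torch

def straight_through_hard_mask(soft: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Binarize `soft` in the forward pass while letting gradients flow to `soft`."""
    # 1.0 where soft > threshold, 0.0 where soft < threshold (torch.sign gives 0.5 exactly at the threshold)
    hard = (torch.sign(soft - threshold) + 1) / 2
    # Forward value equals `hard`; the detached term carries no gradient,
    # so backward treats the output as if it were `soft`
    return (hard - soft).detach() + soft

# Toy stand-ins for mask_h (soft mask) and the input series x
soft = torch.tensor([[0.1, 0.7, 0.4], [0.9, 0.2, 0.6]], requires_grad=True)
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

mask = straight_through_hard_mask(soft)
(mask * x).sum().backward()

print(mask.tolist())                 # [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
print(torch.equal(soft.grad, x))     # True: d/d(soft) of sum(mask * x) is x
```

This single assignment has the same form as the second assignment in the quoted code, `(hard_mask_h - mask_h).detach()+mask_h`.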
The informative part is `mask_h * x`. It appears inside the product `ax = weight_s * mask_h * x`. Thanks.
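The reply can be illustrated with toy tensors. The informative part x* is the elementwise product `mask_h * x`, and `ax` rescales that part by `weight_s`. The sketch below assumes PyTorch and elementwise products as in the quoted `get_features`. All values are made up for illustration, and `x_star` and `x_dropped` are hypothetical names.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 0.5]])         # toy series: 1 sample, 4 time steps
mask_h = torch.tensor([[1.0, 0.0, 1.0, 0.0]])    # binary mask h(x), e.g. after the hard-mask step
weight_s = torch.tensor([[2.0, 2.0, 0.5, 0.5]])  # toy scaling weights

x_star = mask_h * x              # informative part x* = h(x) * x
x_dropped = (1 - mask_h) * x     # entries the mask zeroes out
ax = weight_s * x_star           # same value as weight_s * mask_h * x

print(x_star.tolist())                      # [[1.0, 0.0, 3.0, 0.0]]
print(ax.tolist())                          # [[2.0, 0.0, 1.5, 0.0]]
print(torch.equal(x_star + x_dropped, x))   # True: the two parts add back to x
```

Because multiplication is associative, `weight_s * mask_h * x` equals `weight_s * (mask_h * x)`. The decomposition into x* therefore happens implicitly inside the single line that builds `ax`.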
I understand, thank you!