
Training on labeled data only can achieve much better results than what is reported

Open IssamLaradji opened this issue 2 years ago • 3 comments

Interesting idea of using an adversarial method for leveraging unlabeled data. I am trying to see how much unlabeled data can actually help.

In the plot below, I compare GAN-BERT (orange), which trains on both labeled and unlabeled data, against a basic BERT + classifier model (blue) that trains only on the 109 labeled examples of the TREC data.

The paper reports that the basic model should achieve around 40%, but I am getting 60%, which is very close to GAN-BERT's result. Are you sure that the baseline discussed in the paper is a reasonable one?

[Plot: comparison of GAN-BERT (orange) and the BERT + classifier baseline (blue)]

IssamLaradji avatar Feb 18 '22 05:02 IssamLaradji

Dear IssamLaradji,

First of all, thank you for considering GANBERT.

In our experiments, we simply used the TensorFlow implementation of BERT that was available in the official repository (https://github.com/google-research/bert) during the early months of 2020. The results may depend on several factors, such as the sampling of examples.
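Just to illustrate the point about sampling: this is a minimal sketch (my own, not taken from the original TensorFlow code) of how a small labeled subset such as the 109 TREC examples could be drawn; the random seed and the choice between stratified and uniform sampling can noticeably shift a low-resource baseline.

    import random
    from collections import defaultdict

    def sample_labeled_subset(examples, labels, n_labeled=109, seed=42, stratified=True):
        # Draw n_labeled examples either uniformly at random or per class
        # (stratified); both the seed and the strategy affect the resulting split.
        rng = random.Random(seed)
        if not stratified:
            idx = rng.sample(range(len(examples)), n_labeled)
        else:
            by_class = defaultdict(list)
            for i, y in enumerate(labels):
                by_class[y].append(i)
            per_class = max(1, n_labeled // len(by_class))
            idx = []
            for members in by_class.values():
                idx.extend(rng.sample(members, min(per_class, len(members))))
            # Top up with random leftovers if rounding left the split short.
            chosen = set(idx)
            leftovers = [i for i in range(len(examples)) if i not in chosen]
            idx.extend(rng.sample(leftovers, max(0, n_labeled - len(idx))))
        return [examples[i] for i in idx], [labels[i] for i in idx]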

If you are interested in more advanced options that are available in recent implementations (e.g., Huggingface, where, for example, you can use learning-rate schedulers), you can take a look at https://github.com/crux82/ganbert-pytorch
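As a rough, hypothetical sketch of the kind of scheduler setup the Huggingface stack allows (model, train_dataloader, and the number of epochs are just placeholders here):

    import torch
    from transformers import get_linear_schedule_with_warmup

    # Placeholder sizes: adapt to your own dataloader and epoch budget.
    num_train_epochs = 10
    num_training_steps = len(train_dataloader) * num_train_epochs

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # Inside the training loop, step the scheduler right after the optimizer:
    #   loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()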

Hope it helps.

Best,

Danilo

crux82 avatar Feb 18 '22 09:02 crux82

Thanks for your reply! The implementation is the same as yours, except that I use the discriminator as the classifier on top of "bert-base-cased" and remove the generator part; see the training loop for the base model below, followed by a sketch of the discriminator head I am referring to. Have you tried BERT Transformer + Discriminator as your classifier?

    import tqdm
    from torch.nn.functional import log_softmax, nll_loss

    # Excerpt from the training method of the labeled-only baseline:
    # self.transformer is "bert-base-cased", self.discriminator is the
    # discriminator head, and self.dis_optimizer is its optimizer.
    for batch in tqdm.tqdm(train_dataloader, desc="Training"):
        # Unpack this training batch from the dataloader and move it to the device.
        b_input_ids = batch["input_ids"].to(device)
        b_input_mask = batch["input_masks"].to(device)
        b_labels = batch["label_ids"].to(device)

        # Encode the real (labeled) data with the Transformer.
        model_outputs = self.transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = model_outputs[-1]

        # Get logits from the discriminator head and drop the extra "fake" class,
        # since there is no generator in this baseline.
        _, logits, _ = self.discriminator(hidden_states)
        filtered_logits = logits[:, 0:-1]

        # Supervised loss and discriminator update (nll_loss expects
        # log-probabilities, hence the log_softmax).
        self.dis_optimizer.zero_grad()
        d_loss = nll_loss(log_softmax(filtered_logits, dim=-1), b_labels)
        d_loss.backward()
        self.dis_optimizer.step()
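For clarity, this is roughly what I mean by "Discriminator as classifier": an MLP head with num_labels + 1 outputs, loosely modeled on the one in ganbert-pytorch; the sizes, activation, and dropout below are my own assumptions, not a claim about your exact code.

    import torch.nn as nn

    class Discriminator(nn.Module):
        # An MLP head over the BERT representation with num_labels + 1 outputs;
        # the extra unit is the "fake" class, which the loop above slices off
        # via logits[:, 0:-1] because no generator is used.
        def __init__(self, input_size=768, hidden_size=768, num_labels=6, dropout=0.1):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.LeakyReLU(0.2),
                nn.Dropout(dropout),
            )
            self.logit = nn.Linear(hidden_size, num_labels + 1)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, reps):
            last_rep = self.layers(reps)
            logits = self.logit(last_rep)
            probs = self.softmax(logits)
            return last_rep, logits, probs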

IssamLaradji avatar Feb 18 '22 15:02 IssamLaradji

Hi!

I've obtained similar results to @IssamLaradji!

@crux82 did you consider that simply increasing the number of training epochs for the plain BERT model can yield better quality metrics than applying the whole GAN-BERT approach? Could you please elaborate on why you decided to compare BERT and GAN-BERT after the same number of training epochs?

Thanks in advance!

VirtualRoyalty avatar Feb 13 '23 23:02 VirtualRoyalty