
Train model using domain data

Open Sharathmk99 opened this issue 5 years ago • 14 comments

Hi, thank you for this amazing framework. I'm planning to use it for a question answering system. Can you please help me understand how I can train a model with my own domain data? I know that nboost doesn't support training, but can you help me train outside nboost and plug the model into nboost later?

Thank you.

Sharathmk99 avatar Feb 16 '20 23:02 Sharathmk99

Do you mean train a reranking model? Or a model to extract answer spans from passages?

In the former case, I use the train triples from, say, MS MARCO, which have the format <query> <relevant passage example> <irrelevant passage example>, and turn that into data like:

<query> <relevant passage example> label = 1
<query> <irrelevant passage example> label = 0

and train a binary classification model. Here is sample code in the format of the original BERT release (which is quite old now, so don't use it as-is):

# Requires helpers from the original BERT release (google-research/bert):
# DataProcessor and InputExample are defined in run_classifier.py.
import os

import tokenization
from run_classifier import DataProcessor, InputExample


class MSMarcoProcessor(DataProcessor):

  def _create_example(self, query, doc, label, set_type, i):
    guid = "%s-%s" % (set_type, i)
    text_a = tokenization.convert_to_unicode(query).lower()
    text_b = tokenization.convert_to_unicode(doc).lower()
    if set_type == "test":
      label = "0"
    else:
      label = tokenization.convert_to_unicode(label).lower()
    return InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)

  def get_train_examples(self, data_dir):
    """See base class."""
    print('Converting to Train to tfrecord...')

    train_dataset_path = os.path.join(data_dir, 'triples.train.small.tsv')

    print('Counting number of examples...')
    # Hard-coded for speed; the commented expression counts the real file.
    num_lines = 100000  # sum(1 for line in open(train_dataset_path, 'r'))
    print('{} examples found.'.format(num_lines))
    examples = []

    with open(train_dataset_path, 'r') as f:
      for i, line in enumerate(f):
        if i > 2000000:  # cap the number of triples read
          break
        query, positive_doc, negative_doc = line.rstrip().split('\t')
        # Each triple yields one positive and one negative example;
        # i + 0.5 just keeps the two guids distinct.
        examples.append(self._create_example(query, positive_doc, str(1), 'train', i))
        examples.append(self._create_example(query, negative_doc, str(0), 'train', i + 0.5))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    return []

  def get_test_examples(self, data_dir):
    """See base class."""
    return self.get_dev_examples(data_dir)

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets (boilerplate from the
    original run_classifier template; unused by the TSV loader above)."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[3]).lower()
      text_b = tokenization.convert_to_unicode(line[4]).lower()
      if set_type == "test":
        guid = line[0]
        label = "0"
      else:
        label = tokenization.convert_to_unicode(line[0]).lower()
      examples.append(
        InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
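Stripped of the BERT-release boilerplate, the triples-to-pairs conversion above can be sketched in a few lines of plain Python; the field order follows the triples.train.small.tsv layout (query, positive passage, negative passage):

```python
def triples_to_pairs(lines):
    """Turn each tab-separated (query, positive, negative) triple into
    two (query, passage, label) rows for binary classification."""
    pairs = []
    for line in lines:
        query, positive_doc, negative_doc = line.rstrip('\n').split('\t')
        pairs.append((query, positive_doc, 1))   # relevant passage   -> label 1
        pairs.append((query, negative_doc, 0))   # irrelevant passage -> label 0
    return pairs

print(triples_to_pairs(['what is nboost\ta relevant passage\tan irrelevant passage']))
```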

pertschuk avatar Feb 28 '20 09:02 pertschuk

@pertschuk Retraining the passage re-ranking model. OK, I'll try to train a binary classification model. If I face any problems I'll post here.

Sharathmk99 avatar Feb 29 '20 20:02 Sharathmk99

@Sharathmk99 I would like to know how the results look for the question answering system after training the model. Please post your findings.

pidugusundeep avatar Mar 05 '20 14:03 pidugusundeep

@pidugusundeep We have trained a binary classification model with Hugging Face Transformers on sample data. It works fine. But our goal is to use transfer learning from the bert-tiny model and add our own domain data. I'll keep you posted. Thank you.

Sharathmk99 avatar Mar 05 '20 17:03 Sharathmk99

@pertschuk We have trained a binary classification model using sample data as below.

We used the BERT tokenizer to convert the input examples into features to feed into training, created one more layer for fine-tuning, and saved the checkpoints after training. This is implemented in TensorFlow. The input sample is as below:

id | query | passage | label
1 | 'Query Text' | 'Relevant Passage' | 1
2 | 'Query Text' | 'Irrelevant Passage' | 0
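For illustration, the examples-to-features step can be sketched with a toy whitespace tokenizer; the real pipeline uses BERT's WordPiece tokenizer, and the vocab ids below are made up:

```python
# Toy vocab; real BERT uses vocab.txt, where [CLS]=101 and [SEP]=102.
VOCAB = {"[PAD]": 0, "[CLS]": 101, "[SEP]": 102}

def to_features(query, passage, max_len=16):
    """Pack a (query, passage) pair into padded input ids and segment ids,
    assuming the first [SEP] survives truncation."""
    tokens = ["[CLS]"] + query.lower().split() + ["[SEP]"] + passage.lower().split() + ["[SEP]"]
    ids = [VOCAB.setdefault(t, len(VOCAB) + 100) for t in tokens][:max_len]
    # segment ids: 0 for the query segment (incl. [CLS] and first [SEP]), 1 for the passage
    sep = ids.index(VOCAB["[SEP]"])
    segments = [0] * (sep + 1) + [1] * (len(ids) - sep - 1)
    ids += [0] * (max_len - len(ids))            # pad input ids with [PAD]
    segments += [0] * (max_len - len(segments))  # pad segment ids
    return ids, segments

print(to_features('what is nboost', 'a proxy'))
```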

Could you please guide us on how to add those checkpoints to implement in nboost framework?

Thank you!

jishapjoseph avatar Mar 09 '20 08:03 jishapjoseph

@jishajoseph What version of TensorFlow?

pertschuk avatar Mar 09 '20 17:03 pertschuk

@pertschuk The TensorFlow version is 1.15.0.

jishapjoseph avatar Mar 10 '20 04:03 jishapjoseph

@jishajoseph Try running nboost --model_dir <path_to_your_model> --model TfBertRerankModelPlugin (use the class Tfv1BertRerankModelPlugin for TensorFlow 1.x versions). The model path should be a directory containing a .ckpt file, bert_config.json, and vocab.txt.
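A hypothetical model directory matching that description might look like this ("my_model" and the "model.ckpt" checkpoint prefix are placeholder names; a TF1 checkpoint is usually a set of files sharing one prefix):

```shell
mkdir -p my_model
touch my_model/model.ckpt.index \
      my_model/model.ckpt.data-00000-of-00001 \
      my_model/bert_config.json \
      my_model/vocab.txt
ls my_model
```

nboost would then be pointed at it with --model_dir ./my_model as in the command above.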

pertschuk avatar Mar 10 '20 05:03 pertschuk

@pertschuk Sure..I will try doing the above and I will let you know.

Thank you.

jishapjoseph avatar Mar 10 '20 06:03 jishapjoseph

@jishajoseph Sorry, the flag above should be --model. Running nboost --help shows all available flags.

pertschuk avatar Mar 10 '20 07:03 pertschuk

@pertschuk There is no such plugin present: TfBertRerankModelPlugin

mayankchatteron1 avatar May 06 '20 07:05 mayankchatteron1

@pommedeterresautee @Sharathmk99

I just pushed a clean and up-to-date example of how to train BERT and other transformer models on MS MARCO: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_cross-encoder.py

The result is a Hugging Face Transformers model that can be used with nboost. The models I have trained so far outperform the nboost models at comparable model size and run time (they will soon be added to the Hugging Face model repository).

nreimers avatar Nov 20 '20 16:11 nreimers

https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_cross-encoder.py

The code there uses num_labels=1, which is not compatible with nboost.

wutianhao1973 avatar Feb 10 '21 13:02 wutianhao1973

You can just set num_labels=2 and use it without further changes.
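To see why the two labels matter: a two-logit classifier can be turned into a ranking score by taking the softmax probability of the "relevant" class, as in the original MS MARCO BERT reranker, whereas a num_labels=1 head produces a single regression score with nothing to softmax over. A minimal sketch (the logit ordering (not-relevant, relevant) is an assumption):

```python
import math

def relevance_score(logits):
    """Softmax over (not_relevant, relevant) logits; return P(relevant)."""
    exps = [math.exp(x) for x in logits]
    return exps[1] / sum(exps)

# A passage the classifier likes scores near 1, one it dislikes near 0.
print(relevance_score([-2.0, 3.0]))  # ~0.993
print(relevance_score([3.0, -2.0]))  # ~0.007
```

Passages are then ranked by this score, highest first.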

nreimers avatar Feb 10 '21 13:02 nreimers